diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..615ab6361e2ce540820e8cb5f6a5926837b0ebaa --- /dev/null +++ b/.gitignore @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Misc +outputs/ +checkpoints/* +!checkpoints/README.md + +# Data types +*.jit +*.pt +*.hdr +*.webp +*.pgm +*.tiff +*.tif +*.tar +*.tar.gz +*.gz +*.pkl +*.pt +*.bin + +# Other uncheckable file types +*.zip +*.exe +*.dll +*.swp +*.vscode +*.ipynb +*.DS_Store +*.pyc +*Thumbs.db +*.patch + +# Credential information that should never be checked in +credentials +*.secret + +# ------------------------ BELOW IS AUTO-GENERATED FOR PYTHON REPOS ------------------------ + +# Byte-compiled / optimized / DLL files +**/__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +results/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.config +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Third party +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
+#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# ruff +.ruff_cache + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ +CLIP +.devcontainer/devcontainer.json + +# Coverage +.coverage +coverage.xml + +# JUnit Reports +report.xml + +# CI-CD +temp/ +envs.txt +manifest.json + + +# locks and t5 temp files +*.locks* +*.no_exist* +*models--t5* + +# OneLogger +wandb/ +onelogger.err +onelogger.log diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4c9ad980682246bd6ab0d2bae82232be6dbdcbd4 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e822c9a929cf249d5caed1bd014014f5565874af
--- /dev/null
+++ b/README.md
@@ -0,0 +1,97 @@
+## How to Use
+
+```python
+from transformers import AutoModel
+
+model = AutoModel.from_pretrained(
+    "EthanZyh/DiffusionText2WorldGeneration",
+    cache_dir="./cache",
+    trust_remote_code=True,
+    # turn on offloading on a machine with low GPU memory:
+    # offload_network=True,
+    # offload_tokenizer=True,
+    # offload_text_encoder_model=True,
+    # offload_prompt_upsampler=True,
+    # offload_guardrail_models=True,
+)
+prompt = "Some text prompt to generate a video"
+model(prompt)
+```
+
+![Cosmos Logo](https://github.com/NVIDIA/Cosmos/raw/main/assets/cosmos-logo.png)
+
+--------------------------------------------------------------------------------
+### [Website](https://www.nvidia.com/en-us/ai/cosmos/) | [HuggingFace](https://huggingface.co/collections/nvidia/cosmos-6751e884dc10e013a0a0d8e6) | [GPU-free Preview](https://build.nvidia.com/explore/discover) | [Paper](https://arxiv.org/abs/2501.03575) | [Paper Website](https://research.nvidia.com/labs/dir/cosmos1/)
+
+[NVIDIA Cosmos](https://www.nvidia.com/cosmos/) is a developer-first world foundation model platform designed to help Physical AI developers build their Physical AI systems better and faster. Cosmos contains
+
+1. pre-trained models, available via [Hugging Face](https://huggingface.co/collections/nvidia/cosmos-6751e884dc10e013a0a0d8e6) under the [NVIDIA Open Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/) that allows commercial use of the models for free
+2. training scripts under the [Apache 2 License](https://www.apache.org/licenses/LICENSE-2.0), offered through the [NVIDIA NeMo Framework](https://github.com/NVIDIA/NeMo) for post-training the models for various downstream Physical AI applications
+
+Details of the platform are described in the [Cosmos paper](https://research.nvidia.com/publication/2025-01_cosmos-world-foundation-model-platform-physical-ai). Preview access is available at [build.nvidia.com](https://build.nvidia.com).
+
+## Key Features
+
+- [Pre-trained Diffusion-based world foundation models](cosmos1/models/diffusion/README.md) for Text2World and Video2World generation, where a user can generate visual simulations based on text prompts and video prompts.
+- [Pre-trained Autoregressive-based world foundation models](cosmos1/models/autoregressive/README.md) for Video2World generation, where a user can generate visual simulations based on video prompts and optional text prompts.
+- [Video tokenizers](https://github.com/NVIDIA/Cosmos-Tokenizer) for tokenizing videos into continuous tokens (latent vectors) and discrete tokens (integers) efficiently and effectively.
+- Video curation pipeline for building your own video dataset. [Coming soon]
+- [Post-training scripts](cosmos1/models/POST_TRAINING.md) via NeMo Framework to post-train the pre-trained world foundation models for various Physical AI setups.
+- Pre-training scripts via NeMo Framework for building your own world foundation model. [[Diffusion](https://github.com/NVIDIA/NeMo/tree/main/nemo/collections/diffusion)] [[Autoregressive](https://github.com/NVIDIA/NeMo/tree/main/nemo/collections/multimodal_autoregressive)] [[Tokenizer](https://github.com/NVIDIA/NeMo/tree/main/nemo/collections/diffusion/vae)].
+
+## Model Family
+
+| Model name | Description | Try it out |
+|------------|----------|----------|
+| [Cosmos-1.0-Diffusion-7B-Text2World](https://huggingface.co/nvidia/Cosmos-1.0-Diffusion-7B-Text2World) | Text to visual world generation | [Inference](cosmos1/models/diffusion/README.md) |
+| [Cosmos-1.0-Diffusion-14B-Text2World](https://huggingface.co/nvidia/Cosmos-1.0-Diffusion-14B-Text2World) | Text to visual world generation | [Inference](cosmos1/models/diffusion/README.md) |
+| [Cosmos-1.0-Diffusion-7B-Video2World](https://huggingface.co/nvidia/Cosmos-1.0-Diffusion-7B-Video2World) | Video + Text based future visual world generation | [Inference](cosmos1/models/diffusion/README.md) |
+| [Cosmos-1.0-Diffusion-14B-Video2World](https://huggingface.co/nvidia/Cosmos-1.0-Diffusion-14B-Video2World) | Video + Text based future visual world generation | [Inference](cosmos1/models/diffusion/README.md) |
+| [Cosmos-1.0-Autoregressive-4B](https://huggingface.co/nvidia/Cosmos-1.0-Autoregressive-4B) | Future visual world generation | [Inference](cosmos1/models/autoregressive/README.md) |
+| [Cosmos-1.0-Autoregressive-12B](https://huggingface.co/nvidia/Cosmos-1.0-Autoregressive-12B) | Future visual world generation | [Inference](cosmos1/models/autoregressive/README.md) |
+| [Cosmos-1.0-Autoregressive-5B-Video2World](https://huggingface.co/nvidia/Cosmos-1.0-Autoregressive-5B-Video2World) | Video + Text based future visual world generation | [Inference](cosmos1/models/autoregressive/README.md) |
+| [Cosmos-1.0-Autoregressive-13B-Video2World](https://huggingface.co/nvidia/Cosmos-1.0-Autoregressive-13B-Video2World) | Video + Text based future visual world generation | [Inference](cosmos1/models/autoregressive/README.md) |
+| [Cosmos-1.0-Guardrail](https://huggingface.co/nvidia/Cosmos-1.0-Guardrail) | Guardrail containing pre-Guard and post-Guard for safe use | Embedded in model inference scripts |
+
+## Example Usage
+
+### Inference
+
+Follow the [Cosmos Installation Guide](INSTALL.md) to set up the Docker environment. For inference with the pretrained models, please refer to [Cosmos Diffusion Inference](cosmos1/models/diffusion/README.md) and [Cosmos Autoregressive Inference](cosmos1/models/autoregressive/README.md).
+
+The code snippet below illustrates basic inference usage.
+
+```bash
+PROMPT="A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. \
+The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. \
+A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, \
+suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. \
+The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of \
+field that keeps the focus on the robot while subtly blurring the background for a cinematic effect."
+
+# Example using the 7B model
+PYTHONPATH=$(pwd) python cosmos1/models/diffusion/inference/text2world.py \
+    --checkpoint_dir checkpoints \
+    --diffusion_transformer_dir Cosmos-1.0-Diffusion-7B-Text2World \
+    --prompt "$PROMPT" \
+    --offload_prompt_upsampler \
+    --video_save_name Cosmos-1.0-Diffusion-7B-Text2World
+```
+
+We also offer [multi-GPU inference](cosmos1/models/diffusion/nemo/inference/README.md) support for Diffusion Text2World WFM models through NeMo Framework.
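+
+The same inference script also runs the larger Text2World checkpoint. The following is a sketch that simply swaps the model directory and output name in the 7B command above; the 14B checkpoint name comes from the model family table, and all flags are assumed to behave the same:
+
+```bash
+# Example using the 14B model (same flags as the 7B example above)
+PYTHONPATH=$(pwd) python cosmos1/models/diffusion/inference/text2world.py \
+    --checkpoint_dir checkpoints \
+    --diffusion_transformer_dir Cosmos-1.0-Diffusion-14B-Text2World \
+    --prompt "$PROMPT" \
+    --offload_prompt_upsampler \
+    --video_save_name Cosmos-1.0-Diffusion-14B-Text2World
+```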
+
+### Post-training
+
+NeMo Framework provides GPU-accelerated general post-training for both [diffusion](cosmos1/models/diffusion/nemo/post_training/README.md) and [autoregressive](cosmos1/models/autoregressive/nemo/post_training/README.md) models, with other types of post-training coming soon.
+
+## License and Contact
+
+This project will download and install additional third-party open source software projects. Review the license terms of these open source projects before use.
+
+NVIDIA Cosmos source code is released under the [Apache 2 License](https://www.apache.org/licenses/LICENSE-2.0).
+
+NVIDIA Cosmos models are released under the [NVIDIA Open Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). For a custom license, please contact [cosmos-license@nvidia.com](mailto:cosmos-license@nvidia.com).
diff --git a/RELEASE.md b/RELEASE.md
new file mode 100644
index 0000000000000000000000000000000000000000..99fc9ffb792085eb97a1acc7730e6e6628aabe52
--- /dev/null
+++ b/RELEASE.md
@@ -0,0 +1,7 @@
+# Release Cadence
+
+
+| Version | Description | Date |
+|------------|----------|----------|
+| [v1.0](release_notes/v0p1.md) | Initial diffusion and autoregressive WFMs release | 2025-01-06 |
+| [v0.1](release_notes/v0p1.md) | Initial tokenizer release | 2024-11-06 |
diff --git a/aegis.py b/aegis.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9f420a42ba7de5bc522c2fc1c015006d3edee9a
--- /dev/null
+++ b/aegis.py
@@ -0,0 +1,131 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
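+
+# Aegis prompt guardrail. This module wraps the meta-llama/LlamaGuard-7b base model
+# with the NVIDIA Aegis PEFT adapter to classify text prompts as safe or unsafe, and
+# is used as one of the safety models run by GuardrailRunner. It can also be invoked
+# standalone via the CLI defined at the bottom of this file (--prompt, --checkpoint_dir).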
+ +import argparse + +from .log import log +import torch +from peft import PeftModel +from transformers import AutoModelForCausalLM, AutoTokenizer + +from .categories import UNSAFE_CATEGORIES +from .guardrail_core import ContentSafetyGuardrail, GuardrailRunner +from .misc import misc, Color, timer + +SAFE = Color.green("SAFE") +UNSAFE = Color.red("UNSAFE") + +DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/aegis" + + +class Aegis(ContentSafetyGuardrail): + def __init__( + self, + checkpoint_dir: str = DEFAULT_CHECKPOINT_DIR, + device="cuda" if torch.cuda.is_available() else "cpu", + ) -> None: + self.checkpoint_dir = checkpoint_dir + self.device = device + self.dtype = torch.bfloat16 + base_model_id = "meta-llama/LlamaGuard-7b" + aegis_adapter = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0" + base_model = AutoModelForCausalLM.from_pretrained(base_model_id, cache_dir=self.checkpoint_dir) + self.tokenizer = AutoTokenizer.from_pretrained(base_model_id, cache_dir=self.checkpoint_dir) + self.model = PeftModel.from_pretrained(base_model, aegis_adapter, cache_dir=self.checkpoint_dir) + self.model.to(self.device, dtype=self.dtype).eval() + + def get_moderation_prompt(self, user_prompt: str) -> str: + """Create the moderation prompt for the Aegis model.""" + full_prompt = f"""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories. + + + {"".join([f"{category['instruction']}" for category in UNSAFE_CATEGORIES])} + + + + + """ + full_prompt += f"User: {user_prompt}\n\n" + full_prompt += """ + + Provide your safety assessment for User in the above conversation: + - First line must read 'safe' or 'unsafe'. + - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""" + + return full_prompt + + def get_aegis_block_message(self, moderation_output: str) -> str: + """Extract the blocked category and reason from the Aegis model output.""" + block_msg = "Prompt blocked by Aegis." 
+ try: + # Extract blocked category and reason + if len(moderation_output.split("\n")) >= 2: + moderation_output_list = moderation_output.split("\n") + blocked_category = int(moderation_output_list[1].lstrip("O")) - 1 + if 0 <= blocked_category < len(UNSAFE_CATEGORIES): + blocked_reason = UNSAFE_CATEGORIES[blocked_category]["blocked_reason"] + blocked_category_name = UNSAFE_CATEGORIES[blocked_category]["category"] + block_msg = f"{blocked_category_name}: {blocked_reason}" + except Exception as e: + log.warning(f"Unable to extract blocked category and reason from Aegis output: {e}") + return block_msg + + def filter_aegis_output(self, prompt: str) -> tuple[bool, str]: + """Filter the Aegis model output and return the safety status and message.""" + full_prompt = self.get_moderation_prompt(prompt) + inputs = self.tokenizer([full_prompt], add_special_tokens=False, return_tensors="pt").to(self.device) + output = self.model.generate(**inputs, max_new_tokens=100, pad_token_id=self.tokenizer.eos_token_id) + prompt_len = inputs["input_ids"].shape[-1] + moderation_output = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True) + + if "unsafe" in moderation_output.lower(): + block_msg = self.get_aegis_block_message(moderation_output) + return False, block_msg + else: + return True, "" + + def is_safe(self, prompt: str) -> tuple[bool, str]: + """Check if the input prompt is safe according to the Aegis model.""" + try: + return self.filter_aegis_output(prompt) + except Exception as e: + log.error(f"Unexpected error occurred when running Aegis guardrail: {e}") + return True, "Unexpected error occurred when running Aegis guardrail." + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--prompt", type=str, required=True, help="Input prompt") + parser.add_argument( + "--checkpoint_dir", + type=str, + help="Path to the Aegis checkpoint folder", + default=DEFAULT_CHECKPOINT_DIR, + ) + return parser.parse_args() + + +def main(args): + aegis = Aegis(checkpoint_dir=args.checkpoint_dir) + runner = GuardrailRunner(safety_models=[aegis]) + with timer("aegis safety check"): + safety, message = runner.run_safety_check(args.prompt) + log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}") + log.info(f"Message: {message}") if not safety else None + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/ar_config_tokenizer.py b/ar_config_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..fc20ac60191a1f8a3b250d7ecffa383a467cd5fc --- /dev/null +++ b/ar_config_tokenizer.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
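+
+# Tokenizer configuration for the autoregressive world foundation model: attrs-based
+# config classes for the text and video tokenizers, plus a helper that builds the
+# LazyDict config for the discrete video FSQ tokenizer from a checkpoint path
+# (expected to end in "ema.jit", from which the encoder/decoder JIT paths are derived).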
+ +from typing import Optional + +import attrs + +from .discrete_video import DiscreteVideoFSQStateDictTokenizer +from .ar_networks import CausalDiscreteVideoTokenizer +from .lazy_config_init import LazyCall as L +from .lazy_config_init import LazyDict + + +def create_discrete_video_fsq_tokenizer_state_dict_config( + ckpt_path, pixel_chunk_duration=33, compression_ratio=[8, 16, 16] +) -> LazyDict: + CausalDiscreteFactorizedVideoTokenizerConfig: LazyDict = L(CausalDiscreteVideoTokenizer)( + # The new causal discrete tokenizer, that is at least 2x more efficient in memory and runtime. + # - It relies on fully 3D discrete wavelet transform + # - Uses a layer norm instead of a group norm + # - Factorizes full convolutions into spatial and temporal convolutions + # - Factorizes full attention into spatial and temporal attention + # - Strictly causal, with flexible temporal length at inference. + attn_resolutions=[32], + channels=128, + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + z_channels=16, + z_factor=1, + num_groups=1, + legacy_mode=False, + spatial_compression=16, + temporal_compression=8, + embedding_dim=6, + levels=[8, 8, 8, 5, 5, 5], + name="CausalDiscreteFactorizedVideoTokenizer", + ) + + return L(DiscreteVideoFSQStateDictTokenizer)( + enc_fp=ckpt_path.replace("ema.jit", "encoder.jit"), + dec_fp=ckpt_path.replace("ema.jit", "decoder.jit"), + tokenizer_module=CausalDiscreteFactorizedVideoTokenizerConfig, + name="discrete_video_fsq", + latent_ch=6, + is_bf16=True, + pixel_chunk_duration=pixel_chunk_duration, + latent_chunk_duration=1 + (pixel_chunk_duration - 1) // compression_ratio[0], + max_enc_batch_size=8, + max_dec_batch_size=4, + levels=[8, 8, 8, 5, 5, 5], + compression_ratio=compression_ratio, + ) + + +@attrs.define(slots=False) +class TextTokenizerConfig: + """ + Text tokenizer config + + Args: + config: Config file to define the text tokenizer class. + data_key (str): The input key from data_dict that will be passed to the text tokenizer. + tokenize_here (bool): Whether to use the tokenizer to perform online tokenization. + tokenizer_offset (int): Offset that is added to the tokens. + vocab_size (int): Vocabulary size of the tokenizer. + """ + + config: LazyDict + data_key: str = "" + tokenize_here: bool = False + tokenizer_offset: int = 0 + vocab_size: int = 0 + + +@attrs.define(slots=False) +class VideoTokenizerConfig: + """ + Video tokenizer config + + Args: + config: Config file to define the video tokenizer class. + data_key (str): The input key from data_dict that will be passed to the video tokenizer. + tokenize_here (bool): Whether to use the tokenizer to perform online tokenization. + tokenizer_offset (int): Offset that is added to the tokens. In case of joint text-video tokenizers, we + add an offset to make sure that video tokens and text tokens don't overlap. + vocab_size (int): Vocabulary size of the tokenizer. + max_seq_len (int): Maximum token length for an input video. 
+ """ + + config: LazyDict + data_key: str = "" + tokenize_here: bool = True + tokenizer_offset: int = 0 + vocab_size: int = 0 + max_seq_len: int = -1 + + +@attrs.define(slots=False) +class TokenizerConfig: + """ + Joint tokenizer config + + Args: + text_tokenizer (TextTokenizerConfig): Text tokenizer config file + class_tokenizer (ClassTokenizerConfig): Class tokenizer config file + video_tokenizer (VideoTokenizerConfig): Video tokenizer config file + image_tokenizer (ImageTokenizerConfig): Image tokenizer config file + seq_len (int): Final token sequence length + training_type (str): Type of training we use. Supports ["text_only", "text_to_video", "class_to_image", "image_text_interleaved"] + add_special_tokens (bool): Whether to add special tokens to the output tokens + pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64. + """ + + text_tokenizer: Optional[TextTokenizerConfig] = None + video_tokenizer: Optional[VideoTokenizerConfig] = None + seq_len: int = 4096 + training_type: str = None + add_special_tokens: bool = True + pad_to_multiple_of: Optional[int] = 64 diff --git a/ar_configs_base_model.py b/ar_configs_base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7c0282c5544a109d4dde473bd0856ee95cf880ed --- /dev/null +++ b/ar_configs_base_model.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import attrs + +from .ar_config_tokenizer import TokenizerConfig + + +@attrs.define +class ModelConfig: + """ + A class to hold model configuration arguments. + + Args: + dim (int): The dimensionality of the input and output of each transformer block. + n_layers (int): Number of layers in the transformer. + n_heads (int): Number of attention heads. + n_kv_heads (Optional[int]): Number of key-value heads. If None, defaults to n_heads. Note: this is equivalent to + `num_gqa_groups` in TransformerEngine, where GQA means Grouped Query Attention. + head_dim (Optional[int]): Dimensionality of each head. If None, defaults to dim // n_heads. + vocab_size (int): Vocabulary size. + ffn_hidden_size (int): Hidden size for feedforward network. + norm_eps (float): Epsilon value for normalization. + rope_theta (float): Theta value for rotary positional embeddings. + apply_abs_pos_emb (bool): Whether to apply absolute position embeddings. + max_batch_size (int): Maximum batch size for inference. + max_seq_len (int): Maximum sequence length for input text. + fuse_qkv (bool): Whether to fuse QKV in attention. Defaults to True. + causal_mask (bool): Whether to use causal mask. Defaults to True. + norm_type (str): Type of normalization layer. Choices: "rmsnorm", "fused_rmsnorm", "layernorm", "np_layernorm". + precision (str): Data type for the model. + use_qk_normalization (bool): Whether to enable QK normalization. 
+        ckpt_dir (str): Checkpoint directory.
+        ckpt_path (str): Checkpoint path.
+        apply_yarn (Optional[bool]): Whether to apply YaRN (long-context extension).
+        yarn_scale (Optional[float]): Scale factor for YaRN.
+        yarn_beta_fast (Optional[int]): Beta fast variable for YaRN (i.e., low_freq_factor in the Llama 3.1 RoPE scaling code).
+        yarn_beta_slow (Optional[int]): Beta slow variable for YaRN (i.e., high_freq_factor in the Llama 3.1 RoPE scaling code).
+        original_seq_len (Optional[int]): Original sequence length.
+        vision_encoder (Optional[str]): Vision encoder name.
+        mm_projector (Optional[str]): Multi-modal projector name.
+        vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Defaults to 3; set it to a larger value for inputs with extra channels, e.g., 4 for four-channel images whose last channel is an alpha channel.
+        rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "3D".
+        pytorch_rope_version (Optional[str]): Version of the PyTorch RoPE implementation. Choices: "v1", "v2".
+        original_latent_shape (Optional[list]): Original shape of the latent tensor needed for rope extension.
+        pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
+        insert_cross_attn (bool): Whether to insert cross-attention layers after each multi-head self-attention (MSA) layer.
+        insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers.
+        context_dim (Optional[int]): The dimensionality of the cross-attention embedding, e.g., the T5 embedding feature dim.
+        num_video_frames (Optional[int]): Number of video frames.
+        video_height (Optional[int]): Raw video pixel height dimension.
+        video_width (Optional[int]): Raw video pixel width dimension.
+        video_latent_shape (Optional[list]): Video tokenizer output dimension, in (T, H, W).
+    """
+
+    dim: int = attrs.field(default=4096)
+    n_layers: int = attrs.field(default=32)
+    n_heads: int = attrs.field(default=32)
+    n_kv_heads: Optional[int] = attrs.field(default=8)
+    head_dim: Optional[int] = attrs.field(default=None)
+    vocab_size: int = attrs.field(default=128256)
+    ffn_hidden_size: int = attrs.field(default=14336)
+    norm_eps: float = attrs.field(default=1e-5)
+    rope_theta: float = attrs.field(default=500000)
+    apply_abs_pos_emb: bool = attrs.field(default=False)
+    max_batch_size: int = attrs.field(default=1)
+    max_seq_len: int = attrs.field(default=8192)
+    fuse_qkv: bool = attrs.field(default=False)
+    causal_mask: bool = attrs.field(default=True)
+    norm_type: str = attrs.field(default="rmsnorm")
+    precision: str = attrs.field(default="bfloat16")
+    use_qk_normalization: bool = False
+    tokenizer: Optional[TokenizerConfig] = None
+    ckpt_dir: Optional[str] = attrs.field(default=None)
+    ckpt_path: Optional[str] = attrs.field(default=None)  # If not None, load the model from this path instead of ckpt_dir
+    apply_yarn: Optional[bool] = attrs.field(default=False)
+    yarn_scale: Optional[float] = attrs.field(default=None)
+    yarn_beta_fast: Optional[int] = attrs.field(default=None)
+    yarn_beta_slow: Optional[int] = attrs.field(default=None)
+    original_seq_len: Optional[int] = attrs.field(default=None)
+    vision_encoder: Optional[str] = attrs.field(default=None)
+    vision_encoder_in_channels: Optional[int] = attrs.field(default=3)
+    mm_projector: Optional[str] = attrs.field(default=None)
+    rope_dim: Optional[str] = attrs.field(default="1D")
+    pytorch_rope_version: Optional[str] = attrs.field(default="v2")
+    original_latent_shape: Optional[list] = None
+    pad_to_multiple_of: Optional[int] = None
+    insert_cross_attn: bool = False
+    insert_cross_attn_every_k_layers: int = 1
+    context_dim: Optional[int] = attrs.field(default=1024)
+    # For video training
+    num_video_frames: Optional[int] = None
+    # Raw video pixel dimension
+    video_height: Optional[int] = None
+    video_width: Optional[int] = None
+    # Video tokenizer output dimension, in (T, H, W); computed as (num_video_frames / temporal_compression_factor, video_height / spatial_compression_factor, video_width / spatial_compression_factor)
+    video_latent_shape: Optional[list] = None
+
+    def __getitem__(self, item):
+        return getattr(self, item)
diff --git a/ar_model.py b/ar_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccde6728e0979a0c185fdf509e6b2748c63b8be5
--- /dev/null
+++ b/ar_model.py
@@ -0,0 +1,596 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
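+
+# AutoRegressiveModel: wraps the Cosmos Transformer backbone together with its
+# discrete multimodal tokenizer and, optionally, a vision encoder and multimodal
+# projector. Handles checkpoint loading (safetensors or PyTorch), vocabulary
+# expansion, and fast autoregressive generation (prefill + decode, optionally
+# compiled with torch.compile).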
+ +import json +import os +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Set + +from .log import log +import torch +from safetensors.torch import load_file +from torch.nn.modules.module import _IncompatibleKeys + +from .ar_configs_base_model import ModelConfig +from .ar_config_tokenizer import TokenizerConfig +from .mm_projector import MultimodalProjector +from .ar_transformer import Transformer +from .vit import VisionTransformer, get_vit_config +from .ar_tokenizer import DiscreteMultimodalTokenizer, update_vocab_size +from .checkpoint import ( + get_partial_state_dict, + process_state_dict, + substrings_to_ignore, +) +from .sampling import decode_n_tokens, decode_one_token, prefill +from .misc import misc, Color, timer + + +class AutoRegressiveModel(torch.nn.Module): + """ + A class to build and use a AutoRegressiveModel model for text generation. + + Methods: + build: Build a AutoRegressiveModel instance by initializing and loading a model checkpoint. + generate: Generate text sequences based on provided prompts using the language generation model. + """ + + def __init__( + self, + model: Transformer = None, + tokenizer: DiscreteMultimodalTokenizer = None, + config: ModelConfig = None, + vision_encoder: VisionTransformer = None, + mm_projector: MultimodalProjector = None, + ): + """ + Initialize the AutoRegressiveModel instance with a model and tokenizer. + + Args: + model (Transformer): The Transformer model for text generation. + tokenizer (Tokenizer): The tokenizer for encoding and decoding text. + config (Config): The configuration for the AutoRegressiveModel model. + vision_encoder (VisionTransformer): The vision encoder for the AutoRegressiveModel model. + mm_projector (MultimodalProjector): The multi-modal projector for the AutoRegressiveModel model. + """ + super().__init__() + self.model = model + self.tokenizer = tokenizer + self.config = config + + self.vision_encoder = vision_encoder + self.mm_projector = mm_projector + + @property + def precision(self): + return self.model.precision + + def get_num_params( + self, + ) -> int: + """ + Return the number of parameters in the model. + """ + n_params = sum(p.numel() for p in self.parameters()) + return n_params + + def load_ar_model( + self, + tokenizer_config, + ): + """ + Load the AR model. 
+ """ + model_config = self.config + ckpt_path = model_config.ckpt_path + with timer(f"loading checkpoint from {ckpt_path}"): + if ckpt_path.endswith("safetensors"): + # Load with safetensors API + checkpoint = load_file(ckpt_path, device="cpu") + else: + # The pytorch version + checkpoint = torch.load( + ckpt_path, + map_location="cpu", + mmap=True, # load the checkpoint in memory-mapped mode + weights_only=True, + ) + llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint + orig_precision = torch.get_default_dtype() + precision = getattr(torch, model_config.precision) + torch.set_default_dtype(precision) + log.debug(f"Setting torch default dtype to {precision}") + + model = Transformer( + params=model_config, + tokenizer_config=tokenizer_config, + ) + log.debug( + f"tokenizer tokenizer_config.video_tokenizer.vocab_size {tokenizer_config.video_tokenizer.vocab_size}" + ) + vocab_size = update_vocab_size( + existing_vocab_size=0, + to_be_added_vocab_size=tokenizer_config.video_tokenizer.vocab_size, + training_type=tokenizer_config.training_type, + add_special_tokens=False, + ) + log.debug( + f"tokenizer tokenizer_config.video_tokenizer.vocab_size {tokenizer_config.video_tokenizer.vocab_size} vocab_size {vocab_size}" + ) + # Perform vocab expansion + if vocab_size > model.vocab_size: + log.debug(f"Expanding vocab size to {vocab_size}") + # For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer, + expand_output_layer = not (tokenizer_config.training_type == "text_to_video") + model.expand_vocab( + vocab_size, + init_method="gaussian", + expand_output_layer=expand_output_layer, + ) + # Remove the "model." prefix in the state_dict + llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.") + with timer("loading state_dict into model"): + missing_keys, _ = model.load_state_dict(llm_checkpoint, strict=True) + # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage) + missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")] + assert len(missing_keys) == 0, f"Missing keys: {missing_keys}" + + self.model = model.to(precision).to("cuda") + torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value + + def load_tokenizer(self, tokenizer_config): + """ + Load the tokenizer. + """ + self.tokenizer = DiscreteMultimodalTokenizer(tokenizer_config) + + @staticmethod + def build( + model_config: ModelConfig = ModelConfig(), + tokenizer_config: TokenizerConfig = None, + ) -> "AutoRegressiveModel": + """ + Build a AutoRegressiveModel instance by initializing and loading a model checkpoint. + + Args: + model_config (ModelConfig, optional): The model configuration for the AutoRegressiveModel instance. Defaults to ModelConfig(). + tokenizer_config (TokenizerConfig, optional): The tokenizer configuration for the AutoRegressiveModel instance. Defaults to None. + download_rank_sync (bool, optional): Whether to download the checkpoint in a rank-synchronized manner. Defaults to True. + Returns: + AutoRegressiveModel: An instance of the AutoRegressiveModel class with the loaded model and tokenizer. + + Raises: + AssertionError: If there are no checkpoint files in the specified directory. + + Note: + This method sets the device to CUDA and loads the pre-trained model and tokenizer. 
+ """ + # Initialize model configuration parameters + config_params = {} + + # Load checkpoint and model parameters + + if model_config.ckpt_path is None: + # If ckpt_path is not provided, we assume the model checkpoint is saved in the ckpt_dir + ckpt_dir = model_config.ckpt_dir + + # We prioritize safetensors version over the pytorch version, since the former is + # much faster for checkpoint loading. + checkpoints = sorted(Path(ckpt_dir).glob("*.safetensors")) + if len(checkpoints) == 0: + checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) + + assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" + assert ( + len(checkpoints) == 1 + ), f"multiple checkpoint files found in {ckpt_dir} (currently only one is supported)" + ckpt_path = str(checkpoints[0]) # Assuming single checkpoint for non-parallel case + + if os.path.exists(Path(ckpt_dir) / "config.json"): + with open(Path(ckpt_dir) / "config.json", "r") as f: + config_params = json.loads(f.read()) + else: + log.info( + f"No params.json found in the checkpoint directory ({ckpt_dir}). " f"Using default model config." + ) + + else: + # If ckpt_path is provided, we load the model from the specified path, + # and use the default model configuration + ckpt_path = model_config.ckpt_path + + for key, value in config_params.items(): + if hasattr(model_config, key): + # Override the default model configuration with the parameters from the checkpoint + setattr(model_config, key, value) + + with timer(f"loading checkpoint from {ckpt_path}"): + if ckpt_path.endswith("safetensors"): + # Load with safetensors API + checkpoint = load_file(ckpt_path, device="cpu") + else: + # The pytorch version + checkpoint = torch.load( + ckpt_path, + map_location="cpu", + mmap=True, # load the checkpoint in memory-mapped mode + weights_only=True, + ) + llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint + + if model_config.vision_encoder is not None: + # Take the LLM weights (starting with "model.") from the VLM checkpoint + llm_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="model.") + if model_config.vision_encoder is not None: + # For vanilla VLM ckpt before fine-tuning, `checkpoint['model']` only contains LLM weights, and `checkpoint['vision_encoder']` + # and `checkpoint['mm_projector']` are both for those weights + # For fine-tuned VLM ckpt, `checkpoint['model']` contains all LLM, mm_projector and vision_encoder weights + if "vision_encoder" in checkpoint: + log.debug("Using pretrained vision_encoder") + vit_checkpoint = checkpoint["vision_encoder"] + else: + log.debug("Using fine-tuned vision_encoder") + vit_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="vision_encoder.") + vit_checkpoint = process_state_dict(vit_checkpoint, prefix_to_remove="vision_encoder.") + if "mm_projector" in checkpoint: + log.debug("Using pretrained mm_projector") + projector_checkpoint = checkpoint["mm_projector"] + else: + log.debug("Using fine-tuned mm_projector") + projector_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="mm_projector.") + projector_checkpoint = process_state_dict(projector_checkpoint, prefix_to_remove="mm_projector.") + assert ( + len(vit_checkpoint) > 0 and len(projector_checkpoint) > 0 + ), "vit_checkpoint and projector_checkpoint cannot be empty. We do not support random initialization for vision_encoder and mm_projector." 
+ + tokenizer = DiscreteMultimodalTokenizer(tokenizer_config) + orig_precision = torch.get_default_dtype() + precision = getattr(torch, model_config.precision) + torch.set_default_dtype(precision) + log.debug(f"Setting torch default dtype to {precision}") + + model = Transformer( + params=model_config, + tokenizer_config=tokenizer_config, + ) + model_kwargs = {} + + if model_config.vision_encoder is not None: + assert model_config.mm_projector is not None, "mm_projector must be provided if vision_encoder is provided." + vit_config = get_vit_config(model_config.vision_encoder) + vision_encoder = VisionTransformer.build( + vit_config, + ) + + mm_projector = MultimodalProjector( + mm_projector_type=model_config.mm_projector, in_dim=vit_config["dim"], out_dim=model_config["dim"] + ) + model_kwargs.update({"vision_encoder": vision_encoder, "mm_projector": mm_projector}) + + # Perform vocab expansion + if tokenizer.vocab_size > model.vocab_size: + log.debug(f"Expanding vocab size to {tokenizer.vocab_size}") + # For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer, + expand_output_layer = not (tokenizer.training_type == "text_to_video") + model.expand_vocab( + tokenizer.vocab_size, + init_method="gaussian", + expand_output_layer=expand_output_layer, + ) + + # Remove the "model." prefix in the state_dict + llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.") + with timer("loading state_dict into model"): + missing_keys, unexpected_keys = model.load_state_dict(llm_checkpoint, strict=True) + # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage) + missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")] + assert len(missing_keys) == 0, f"Missing keys: {missing_keys}" + + if model_config.vision_encoder is not None: + vision_encoder.load_state_dict(vit_checkpoint) + mm_projector.load_state_dict(projector_checkpoint) + if model_config.vision_encoder_in_channels != 3: + vision_encoder.expand_in_channels(model_config.vision_encoder_in_channels) + + model = model.to(precision) # ensure model parameters are in the correct precision + log.debug(f"Model config: {model_config}") + + model_class = AutoRegressiveModel + + torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value + + return model_class(model, tokenizer, model_config, **model_kwargs) + + @torch.no_grad() + def generate( + self, + prompt_tokens: List[List[int]] | torch.Tensor, + max_gen_len: int, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + num_gen_seq: int = 1, + logprobs: bool = False, + echo: bool = False, + seed: int = None, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + compile_sampling: bool = True, + compile_prefill: bool = False, + verbose: bool = True, + stop_tokens: Optional[Set[int]] = None, + images: Optional[torch.Tensor] = None, + ): + """ + Autoregressive generation built upon the gpt-fast implementation (https://github.com/pytorch-labs/gpt-fast). + + Args: + prompt_tokens (List[List[int]] | torch.Tensor): A single prompt of shape (1, seq_len). + max_gen_len (int): Maximum length of the generated text sequence. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_k (int, optional): Top-k value for top-k sampling. Defaults to None. + top_p (float, optional): Top-p probability threshold for nucleus sampling. 
Defaults to None. + num_gen_seq (int, optional): Number of outputs to generate given the same prompt. Defaults to 1. When temperature == 0, num_gen_seq must be 1 because the generation is deterministic. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + logit_clipping_range (list, optional): Range of logits to clip. Defaults to []. + seed (int, optional): Random seed for reproducibility. Defaults to None. + compile_sampling (bool, optional): Flag indicating whether to compile the decoding function. Defaults to True. + compile_prefill (bool, optional): Flag indicating whether to compile the prefill function. Defaults to False. + verbose (bool, optional): Flag indicating whether to print the the time. Defaults to False. + """ + assert top_k is None or top_p is None, f"Only one of top_k ({top_k} or top_p ({top_p} should be specified." + if temperature == 0: + top_p, top_k = None, None + log.debug("Setting top_p and top_k to None because temperature is 0") + if top_p is not None: + log.debug(f"Using top-p sampling with p={top_p} and temperature={temperature}") + elif top_k is not None: + log.debug(f"Using top-k sampling with k={top_k} and temperature={temperature}") + else: + log.debug("Not applying top-k or top-p sampling. Will use top-k sampling with k=None") + + orig_precision = torch.get_default_dtype() + torch.set_default_dtype(self.precision) + + torch._inductor.config.coordinate_descent_tuning = True + torch._inductor.config.triton.unique_kernel_names = True + # Experimental features to reduce compilation times, will be on by default in future + torch._inductor.config.fx_graph_cache = True + + if seed is not None: + misc.set_random_seed(seed) + + assert not logprobs, "logprobs are not supported for fast_generate yet" + # Examine if the function prefil and decode_one_token functions are compiled yet. If not, compile them based on the flags + if compile_sampling and not getattr(self, "inference_decode_compiled", False): + self.decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) + self.inference_decode_compiled = True + log.info("Compiled AR sampling function. Note: the first run will be slower due to compilation") + if compile_prefill and not getattr(self, "inference_prefill_compiled", False): + self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True) + self.inference_prefill_compiled = True + log.info("Compiled prefill function. Note: the first run will be slower due to compilation") + + if not hasattr(self, "decode_one_token"): + self.decode_one_token = decode_one_token + if not hasattr(self, "prefill"): + self.prefill = prefill + + # Initialization and Assertions + if isinstance(self.model.params, list): + # During training, model.params is a list + log.debug( + f"Find self.model.params is a list, use self.config instead. 
Get max_batch_size={self.config.max_batch_size}, max_seq_len={self.config.max_seq_len}" + ) + params = self.config + else: + params = self.model.params + if isinstance(prompt_tokens, list): + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cuda") + if prompt_tokens.ndim == 1: + prompt_tokens = prompt_tokens.view(1, -1) + else: + assert prompt_tokens.ndim == 2, f"prompt_tokens has shape {prompt_tokens.shape}" + batch_size, prompt_len = prompt_tokens.shape + total_len = min(params.max_seq_len, max_gen_len + prompt_len) + if max_gen_len + prompt_len > params.max_seq_len: + log.warning( + f"max_gen_len + prompt_len={max_gen_len + prompt_len} exceeds max_seq_len={params.max_seq_len}, truncate max_gen_len to {params.max_seq_len - prompt_len}" + ) + max_gen_len = params.max_seq_len - prompt_len + + if context_mask is not None: + context_mask = context_mask.to(dtype=torch.bool) + if context_mask.ndim == 2: + assert ( + context_mask.shape[0] == batch_size + ), f"batch_size mismatch: {context_mask.shape[0]} != {batch_size}" + # Unsqueeze it to make it of shape [batch_size, 1, 1, context_seq_len] + context_mask = context_mask.view(batch_size, 1, 1, -1) + + if num_gen_seq > 1: + assert ( + batch_size == 1 + ), f"num_gen_seq > 1 is only supported for a single prompt, got {len(prompt_tokens)} prompts" + log.debug(f"Generating {num_gen_seq} sequences with the same prompt") + assert ( + num_gen_seq <= params.max_batch_size + ), f"num_gen_seq={num_gen_seq} exceeds max_batch_size={params.max_batch_size}" + # repeat the prompt tokens for num_gen_seq times + prompt_tokens = prompt_tokens.repeat(num_gen_seq, 1) + assert prompt_tokens.shape == ( + num_gen_seq, + prompt_len, + ), f"prompt_tokens must be of shape (num_gen_seq, seq_len), got {prompt_tokens.shape}" + batch_size = len(prompt_tokens) + + # create an empty tensor of the expected final shape and fill in the current tokens + empty = torch.empty(batch_size, total_len, dtype=prompt_tokens.dtype, device=prompt_tokens.device) + empty[:, :prompt_len] = prompt_tokens + seq = empty + input_pos = torch.arange(0, prompt_len, device="cuda") + + if verbose: + prefill_start = time.time() + + if images is not None: + images = images.to(device=prompt_tokens.device, dtype=torch.bfloat16) + prompt_token_embeddings = self.embed_vision_language_features(prompt_tokens, images) + else: + prompt_token_embeddings = None + + if context is not None: + context = context.to(device=prompt_tokens.device, dtype=self.precision) + + # Prefill stage + next_token = self.prefill( + self.model, + input_pos=input_pos, + tokens=prompt_tokens if prompt_token_embeddings is None else None, + token_embeddings=prompt_token_embeddings, + temperature=temperature, + top_k=top_k, + top_p=top_p, + context=context, + context_mask=context_mask, + ) + if verbose: + prefill_time = time.time() - prefill_start + + seq[:, [prompt_len]] = next_token.to(dtype=seq.dtype) + input_pos = torch.tensor([prompt_len], dtype=torch.long, device="cuda") + stop_tokens = self.tokenizer.stop_tokens if stop_tokens is None else stop_tokens + stop_tokens = torch.tensor(list(stop_tokens), dtype=torch.long, device="cuda") + + if verbose: + decode_start = time.time() + # Decode stage + generated_tokens = decode_n_tokens( + self.model, + next_token.view(batch_size, -1), + input_pos, + max_gen_len - 1, + temperature=temperature, + top_k=top_k, + top_p=top_p, + stop_tokens=stop_tokens, + decode_one_token_function=self.decode_one_token, + context=context, + context_mask=context_mask, + ) + gen_len = 
len(generated_tokens) + if verbose: + decode_time = time.time() - decode_start + prefill_throughput = prompt_len / prefill_time + decode_throughput = gen_len / decode_time + log.debug(f"[Prefill] Time: {prefill_time:.2f}s; Throughput: {prefill_throughput:.2f} tokens/s") + log.debug(f"[Decode] Time: {decode_time:.2f}s; Throughput: {decode_throughput:.2f} tokens/s") + + generated_tokens = torch.cat(generated_tokens, dim=1) + + log.debug(f"generated_tokens: {generated_tokens.shape}") + seq = seq[:, : prompt_len + 1 + gen_len] + seq[:, prompt_len + 1 :] = generated_tokens + if not echo: + seq = seq[:, prompt_len:] + + torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value + + return seq, None + + def embed_vision_language_features(self, input_ids: torch.Tensor, images: torch.tensor) -> torch.Tensor: + """ + Embed vision and language features into a combined representation. + + Args: + input_ids (torch.Tensor): Input token IDs. + images (torch.tensor): Input images. + + Returns: + torch.Tensor: Combined vision-language features. + + Raises: + AssertionError: If vision encoder or mm projector is not initialized, + or if dimensions mismatch. + """ + # Ensure vision encoder and mm projector are initialized + assert self.vision_encoder is not None + assert self.mm_projector is not None + + # Get image token ID and validate it + image_token_id = self.vision_encoder.image_token_id + assert isinstance(image_token_id, int) and image_token_id >= 0, f"Invalid image_token_id: {image_token_id}" + + # Identify text and image locations in the input + text_locations = input_ids != image_token_id + image_locations = input_ids == image_token_id + + # Process text features + text_features = self.model.tok_embeddings(input_ids[text_locations]) + + # Process image features + images = images.to(device=text_features.device, dtype=text_features.dtype) + vit_outputs = self.vision_encoder(images) + image_features = self.mm_projector(vit_outputs) + + # Get dimensions + B, seq_len = input_ids.shape + N_total = B * seq_len + N_txt, D_txt = text_features.shape + N_img, N_patch, D_img = image_features.shape + + # Reshape image features + image_features = image_features.reshape(N_img * N_patch, D_img) + + # Validate dimensions + assert D_txt == D_img, f"Text features dim {D_txt} should be equal to image features dim {D_img}" + assert ( + N_total == N_txt + N_img * N_patch + ), f"seq_len {seq_len} should be equal to N_txt + N_img*N_Patch {(N_txt, N_img * N_patch, image_locations.sum().item())}" + + # Combine text and image features + combined_features = torch.empty( + (B, seq_len, D_txt), + dtype=text_features.dtype, + device=text_features.device, + ) + combined_features[text_locations, :] = text_features + combined_features[image_locations, :] = image_features + + return combined_features + + def state_dict(self, *args, **kwargs): + """ + Process the state dict (e.g., remove "_extra_state" keys imposed by TransformerEngine for FP8). + """ + state_dict = super().state_dict(*args, **kwargs) + return process_state_dict(state_dict) + + def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False): + """ + Ignore the missing keys with substrings matching `substring_to_ignore` (e.g., "_extra_state" keys imposed by + TransformerEngine for FP8). 
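# --- Illustrative aside (not part of this patch) -----------------------------------
# The generate() method above hands the actual token selection to the gpt-fast style
# prefill/decode_n_tokens helpers, which are not included in this diff. The sketch
# below shows how temperature, top-k and top-p are commonly combined when sampling a
# next token from logits; sample_next_token is a hypothetical name and only a
# plausible stand-in for those helpers, not the repository's implementation.
import torch


def sample_next_token(
    logits: torch.Tensor,  # (batch, vocab_size) logits for the next position
    temperature: float = 1.0,
    top_k: int | None = None,
    top_p: float | None = None,
) -> torch.Tensor:
    """Return one sampled token id per row; greedy argmax when temperature == 0."""
    if temperature == 0:
        return logits.argmax(dim=-1, keepdim=True)

    logits = logits / temperature
    if top_k is not None:
        # Keep only the k largest logits, push everything else to -inf.
        kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_largest, float("-inf"))
    if top_p is not None:
        # Nucleus sampling: drop the tail once cumulative probability exceeds top_p,
        # always keeping the token that crosses the threshold.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        drop = sorted_probs.cumsum(dim=-1) - sorted_probs > top_p
        sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
        logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)

    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # (batch, 1)


if __name__ == "__main__":
    fake_logits = torch.randn(2, 32)
    print(sample_next_token(fake_logits, temperature=0.8, top_k=5).shape)  # torch.Size([2, 1])
# ------------------------------------------------------------------------------------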
+ """ + state_dict = process_state_dict(state_dict) + missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False, assign=assign) + actual_missing_keys = [] + for key in missing_keys: + if not any(substring in key for substring in substrings_to_ignore): + actual_missing_keys.append(key) + if strict: + if len(actual_missing_keys) > 0 or len(unexpected_keys) > 0: + raise ValueError(f"Missing keys: {actual_missing_keys}\n\nUnexpected keys: {unexpected_keys}") + return _IncompatibleKeys(actual_missing_keys, unexpected_keys) diff --git a/ar_modules_attention.py b/ar_modules_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..865317a80aa2af41055cedac2e1bda5f0af88f14 --- /dev/null +++ b/ar_modules_attention.py @@ -0,0 +1,262 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Union + +import torch +from torch import nn + +from .ar_modules_embedding import RotaryPositionEmbedding +from .ar_modules_normalization import create_norm + + +class Attention(nn.Module): + """ + Attenion layer with KV cache. + """ + + def __init__( + self, + n_heads: int, + n_kv_heads: Union[int, None], + dim: int, + max_batch_size: int, + max_seq_len: int, + context_dim: Optional[int] = None, + use_qk_normalization: bool = False, + norm_type: str = "rmsnorm", + norm_eps: float = 1e-5, + causal_mask: Optional[bool] = True, + head_dim: Optional[int] = None, + fuse_qkv: bool = False, + precision: str = "bfloat16", + attn_type: str = "self", + ): + """ + Initializes the GQA module. + + Args: + n_heads (int): The number of attention heads. + n_kv_heads (int, optional): The number of key-value attention heads. None defaults to n_heads. + dim (int): The dimensionality of the input and output. + max_batch_size (int): The maximum batch size. + max_seq_len (int): The maximum sequence length. + context_dim (int, optional): The dimensionality of the context for cross-attn. Defaults to None. + use_qk_normalization (bool, optional): Whether to apply QK normalization. Defaults to False. + norm_type (str, optional): The type of normalization layer. Defaults to "rmsnorm". + norm_eps (float, optional): The epsilon value for normalization. Defaults to 1e-5. + causal_mask (bool, optional): Whether to use causal mask. Defaults to True. + head_dim (int, optional): The dimensionality of each attention head. If None, defaults to dim // n_heads. + fuse_qkv (bool, optional): Whether to fuse QKV. Defaults to False. + precision (str, optional): The precision of the module. Defaults to "bfloat16". + attn_type (str, optional): The type of attention. Defaults to "self". 
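# --- Illustrative aside (not part of this patch) -----------------------------------
# The fuse_qkv option described in the docstring above replaces the separate wq/wk/wv
# projections with a single wqkv whose weight is the row-wise concatenation of the
# three matrices (which is what the load-time hook produces), and the forward pass
# splits the fused output back apart. A minimal sketch with made-up toy sizes:
import torch
import torch.nn as nn

dim, n_heads, n_kv_heads, head_dim = 64, 8, 2, 8
q_size, kv_size = n_heads * head_dim, n_kv_heads * head_dim

wq = nn.Linear(dim, q_size, bias=False)
wk = nn.Linear(dim, kv_size, bias=False)
wv = nn.Linear(dim, kv_size, bias=False)

wqkv = nn.Linear(dim, q_size + 2 * kv_size, bias=False)
with torch.no_grad():
    # nn.Linear stores weights as (out_features, in_features), so concatenating along
    # dim 0 stacks the three projections into one matrix.
    wqkv.weight.copy_(torch.cat([wq.weight, wk.weight, wv.weight], dim=0))

x = torch.randn(2, 5, dim)
xq, xk, xv = wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
assert torch.allclose(xq, wq(x), atol=1e-5)
assert torch.allclose(xk, wk(x), atol=1e-5)
assert torch.allclose(xv, wv(x), atol=1e-5)
# ------------------------------------------------------------------------------------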
+ """ + super().__init__() + assert attn_type in ["self", "cross", "full"], f"Invalid attention type: {attn_type}" + self.attn_type = attn_type + context_dim = dim if context_dim is None else context_dim + + self.dim = dim + self.context_dim = context_dim + self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads + self.n_local_kv_heads = self.n_kv_heads + self.n_local_heads = n_heads + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = dim // n_heads if head_dim is None else head_dim + self.causal_mask = causal_mask + self.fuse_qkv = fuse_qkv + self.precision = precision + + if fuse_qkv: + assert context_dim == dim, f"Fuse QKV requires context_dim ({context_dim}) to be equal to dim ({dim})" + self.total_local_head_dim = (self.n_local_heads + 2 * self.n_local_kv_heads) * self.head_dim + self.wqkv = nn.Linear(dim, self.total_local_head_dim, bias=False) + # Register hook to load fused QKV weights + self._register_load_state_dict_pre_hook(self.load_hook) + else: + self.wq = nn.Linear(dim, self.n_local_heads * self.head_dim, bias=False) + self.wk = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(self.n_local_heads * self.head_dim, dim, bias=False) + + self.max_batch_size = max_batch_size + self.max_seq_len = max_seq_len + + if self.attn_type == "self": + # Cache for key and value tensors + self.init_kv_cache() + + # QK normalization layers + if use_qk_normalization: + self.q_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps) + self.k_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps) + + self.use_qk_normalization = use_qk_normalization + + self.to(dtype=getattr(torch, self.precision)) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def init_kv_cache(self, dtype=None): + cache_shape = (self.max_batch_size, self.n_local_kv_heads, self.max_seq_len, self.head_dim) + if dtype is None: + dtype = getattr(torch, self.precision) + if self.attn_type == "self": + self.cache_k = torch.zeros(cache_shape, dtype=dtype).cuda() + self.cache_v = torch.zeros(cache_shape, dtype=dtype).cuda() + + def forward( + self, + x: torch.Tensor, + rope: RotaryPositionEmbedding, + input_pos: torch.Tensor, + mask: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + ): + """ + Forward pass of GQA. + + Args: + x: The input tensor of shape (batch_size, seq_len, dim). + rope: The rotary positional embedding module. + input_pos: The starting position of the current sequence. + mask: The attention mask tensor. + context: The context tensor of shape (batch_size, context_len, dim). + + Returns: + The output tensor after applying GQA. 
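# --- Illustrative aside (not part of this patch) -----------------------------------
# The forward pass below writes the new keys/values into a pre-allocated cache at the
# positions given by input_pos, so prefill fills the whole prompt at once while each
# decode step writes a single slot. A standalone sketch of that indexing pattern with
# toy sizes (the real cache shape is (max_batch_size, n_local_kv_heads, max_seq_len,
# head_dim)):
import torch

max_batch, n_kv_heads, max_seq, head_dim = 2, 2, 16, 8
cache_k = torch.zeros(max_batch, n_kv_heads, max_seq, head_dim)

# Prefill: write the whole prompt at positions [0, prompt_len).
prompt_len = 5
xk_prompt = torch.randn(max_batch, n_kv_heads, prompt_len, head_dim)
cache_k[:, :, torch.arange(prompt_len)] = xk_prompt

# One decode step: write a single new key at position prompt_len.
xk_step = torch.randn(max_batch, n_kv_heads, 1, head_dim)
cache_k[:, :, torch.tensor([prompt_len])] = xk_step

# Attention then reads the cached prefix; masking of the still-empty slots is handled
# by the attention mask outside this module.
keys = cache_k[:, :, : prompt_len + 1]
print(keys.shape)  # torch.Size([2, 2, 6, 8])
# ------------------------------------------------------------------------------------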
+ """ + bsz, seqlen, _ = x.shape + + # Use one single module to handle both self-attn and cross-attn + context = x if context is None else context + context_len = seqlen if context is None else context.shape[1] + + if self.fuse_qkv: + q_size = self.n_local_heads * self.head_dim + kv_size = self.n_local_kv_heads * self.head_dim + xq, xk, xv = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1) + else: + # Compute query, key, and value projections + xq, xk, xv = self.wq(x), self.wk(context), self.wv(context) + + # Reshape projections + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, context_len, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, context_len, self.n_local_kv_heads, self.head_dim) + + # QK normalization + if self.use_qk_normalization: + xq = self.q_norm(xq) + xk = self.k_norm(xk) + + # Apply rotary positional embeddings to queries and keys + # Only apply RoPE to self-attention! + if self.attn_type in ["self", "full"]: + xq, xk = rope(xq, xk, input_pos, seqlen) + + xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv)) + # xq: (bs, n_local_heads, seqlen, head_dim) + # xk: (bs, n_kv_heads, cache_len + context_len, head_dim) + # xv: (bs, n_kv_heads, cache_len + context_len, head_dim) + if self.attn_type == "self": + # Update cache with current key and value tensors + assert input_pos is not None + self.cache_k[:bsz, :, input_pos] = xk + self.cache_v[:bsz, :, input_pos] = xv + keys, values = ( + self.cache_k[:bsz, :, :], + self.cache_v[:bsz, :, :], + ) + else: + keys, values = xk, xv + + # Repeat keys and values if necessary + keys = keys.repeat_interleave(self.n_rep, dim=1) # (bs, n_local_heads, cache_len + context_len, head_dim) + values = values.repeat_interleave(self.n_rep, dim=1) # (bs, n_local_heads, cache_len + context_len, head_dim) + + # For self-attention, `is_causal` should be set to False when KV cache is pre-computed and used, + # since the masking is handled outside this attention module. + # For cross-attention, it's always full-attn without causal mask + is_causal = False + output = scaled_dot_product_attention( + xq, + keys, + values, + head_dim=self.head_dim, + mask=mask, + is_causal=is_causal, + dropout_p=0.0, + ) + output = output.view(bsz, seqlen, -1) + output = self.wo(output) + return output + + +def scaled_dot_product_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + head_dim: int, + mask: Optional[torch.Tensor] = None, + is_causal: Optional[bool] = None, + dropout_p: float = 0.0, +) -> torch.Tensor: + """ + PyTorch's native implementation of Flash Attention 2. + + If `is_causal` is given, then the causal attention mask is applied accordingly: + - If `is_causal` is True, the standard upper-left causal attention masking is applied. + - If `is_causal` is False, no attention mask is applied, unless an explicit mask tensor is + provided (i.e., `mask is not None`). + + If `is_causal` is not given (i.e., `is_causal is None`), then the attention mask is applied + based on the provided mask tensor: + - If no explicit attention mask is given (i.e., `mask is None`), `is_causal` is set to True, + leading to the standard upper-left causal attention masking. + - If an attention mask is given (i.e., `mask is not None`), the provided mask is used, + and `is_causal` is set to False. + + Args: + q (torch.Tensor): Query tensor + k (torch.Tensor): Key tensor + v (torch.Tensor): Value tensor + head_dim (int): Dimension of each attention head + mask (Optional[torch.Tensor], optional): Attention mask. 
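# --- Illustrative aside (not part of this patch) -----------------------------------
# The wrapper below forwards to torch.nn.functional.scaled_dot_product_attention with
# an explicit scale of 1/sqrt(head_dim). A quick sanity check that the fused call (no
# causal mask, as in the is_causal=False path used with a pre-computed KV cache)
# matches the textbook softmax(Q K^T / sqrt(d)) V formulation:
import math

import torch
import torch.nn.functional as F

b, h, s, d = 1, 2, 4, 8
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))
scale = 1.0 / math.sqrt(d)

# Reference: materialize the attention matrix explicitly.
attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
reference = attn @ v

# Fused kernel with the same scale and no causal masking.
fused = F.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=False)
assert torch.allclose(fused, reference, atol=1e-5)
# ------------------------------------------------------------------------------------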
Defaults to None. + is_causal (Optional[bool], optional): Whether to apply causal attention mask. Defaults to None. + dropout_p (float, optional): Dropout rate. Defaults to 0.0. + + Returns: + torch.Tensor: Output tensor after applying scaled dot-product attention + """ + scale = 1.0 / math.sqrt(head_dim) + if is_causal is None: + is_causal = mask is None + y = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=dropout_p, + scale=scale, + is_causal=is_causal, + ) + return y.transpose(1, 2).contiguous() diff --git a/ar_modules_embedding.py b/ar_modules_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..2a19d3b8e2a1cf29c182f7b25a25d4c1e10089da --- /dev/null +++ b/ar_modules_embedding.py @@ -0,0 +1,491 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple + +import numpy as np +import torch +from einops import rearrange, repeat + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def _rotate_half_te(x: torch.Tensor) -> torch.Tensor: + """ + change sign so the last dimension becomes [-odd, +even]. + Adopted from TransformerEngine. + Source: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py + """ + x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2))) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def _apply_rotary_pos_emb_te( + t: torch.Tensor, + cos_freqs: torch.Tensor, + sin_freqs: torch.Tensor, +) -> torch.Tensor: + """ + Apply rotary positional embedding tensor to the input tensor. + Adopted from TransformerEngine. + Source: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py + + Parameters + ---------- + t: torch.Tensor + Input tensor of shape `[b, s, h, d]`, on which + rotary positional embedding will be applied. 
+ cos_freqs: torch.Tensor + Cosine component of rotary positional embedding tensor of shape `[s, 1, 1, d]` and dtype 'float', + sin_freqs: torch.Tensor + Sine component of rotary positional embedding tensor of shape `[s, 1, 1, d]` and dtype 'float', + """ + rot_dim = cos_freqs.shape[-1] + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + t = (t * cos_freqs) + (_rotate_half_te(t) * sin_freqs) + output = torch.cat((t, t_pass), dim=-1) + return output + + +class RotaryPositionEmbedding(torch.nn.Module): + """ + Rotary Position Embedding module as described in the paper: + https://arxiv.org/abs/2104.09864 + + This module implements rotary positional embeddings, which are used to + enhance the performance of transformer models. + + Args: + dim (int): Dimensionality of the input tensor. + max_position_embeddings (Optional[int]): Maximum position embeddings. + original_max_position_embeddings (Optional[int]): Original maximum position embeddings. + rope_theta (Optional[float]): Base for the frequency calculation. + apply_yarn (Optional[bool]): Whether to apply YaRN (Yet another Rotary). + scale (Optional[int]): Scaling factor for the frequency calculation. + extrapolation_factor (Optional[int]): Extrapolation factor for the frequency extension. + attn_factor (Optional[int]): Attention factor for the frequency calculation. + beta_fast (Optional[int]): Fast beta value for the YaRN frequency calculation. + beta_slow (Optional[int]): Slow beta value for the YaRN frequency calculation. + rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "2D", "3D". + latent_shape (Optional[List[int]]): Shape of the latent tensor for video or image inputs. + original_latent_shape (Optional[List[int]]): Original shape of the latent tensor for video or image inputs. + pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value. 
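# --- Illustrative aside (not part of this patch) -----------------------------------
# For the plain "1D" case above (no YaRN), the cached table is built from the standard
# RoPE inverse frequencies 1 / theta^(2i/dim) and an outer product with the position
# indices, as compute_freqs() does. A shape walk-through with toy numbers:
import torch

dim, max_pos, theta = 16, 32, 10000.0

# Inverse frequency for every pair of channels: (dim/2,)
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

# One rotation angle per (position, frequency) pair: (max_pos, dim/2)
positions = torch.arange(max_pos, dtype=torch.float32)
freqs = torch.einsum("i,j->ij", positions, inv_freq)

# Duplicating along the channel axis gives one angle per channel; the cos/sin of this
# table is what the rotary caches ultimately store: (max_pos, dim)
emb = torch.cat((freqs, freqs), dim=-1)
print(inv_freq.shape, freqs.shape, emb.cos().shape)
# torch.Size([8]) torch.Size([32, 8]) torch.Size([32, 16])
# ------------------------------------------------------------------------------------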
+ """ + + def __init__( + self, + dim: int, + max_position_embeddings: Optional[int] = None, + original_max_position_embeddings: Optional[int] = None, + rope_theta: Optional[float] = 10000.0, + apply_yarn: Optional[bool] = False, + scale: Optional[int] = None, + extrapolation_factor: Optional[int] = 1, + attn_factor: Optional[int] = 1, + beta_fast: Optional[int] = 32, + beta_slow: Optional[int] = 1, + rope_dim: Optional[str] = "1D", + latent_shape: Optional[List[int]] = None, + original_latent_shape: Optional[List[int]] = None, + pad_to_multiple_of: Optional[int] = None, + ): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.rope_theta = rope_theta + self.apply_yarn = apply_yarn + self.scale = scale + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = 1.0 + self.rope_dim = rope_dim + self.latent_shape = latent_shape + self.original_latent_shape = original_latent_shape + self.pad_to_multiple_of = pad_to_multiple_of + self.get_inv_freq(torch.cuda.current_device()) + + def get_mscale(self, scale: float = 1.0) -> float: + """Get the magnitude scaling factor for YaRN.""" + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + def forward(self, seq_len: Optional[int] = None) -> torch.Tensor: + """ + Forward pass for the rotary position embedding. + + Args: + seq_len (Optional[int]): Length of the sequence. + + Returns: + torch.Tensor: The computed frequencies for positional embedding. + """ + + if self.apply_yarn and seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + self.freqs = self.compute_freqs() + + return self.freqs + + def compute_freqs( + self, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute the spatial frequencies for the latent tensor.""" + self.seq = torch.arange(self.max_seq_len_cached, dtype=torch.float).cuda() + if self.rope_dim == "1D": + emb = torch.einsum("i,j->ij", self.seq, self.inv_freq) + + elif self.rope_dim == "2D": + H, W = self.latent_shape + half_emb_h = torch.outer(self.seq[:H], self.spatial_inv_freq) + half_emb_w = torch.outer(self.seq[:W], self.spatial_inv_freq) + emb = torch.cat( + [ + repeat(half_emb_h, "h d -> h w d", w=W), + repeat(half_emb_w, "w d -> h w d", h=H), + ] + * 2, + dim=-1, + ) + emb = rearrange(emb, "h w d -> (h w) 1 1 d").float() + + elif self.rope_dim == "3D": + T, H, W = self.latent_shape + half_emb_t = torch.outer(self.seq[:T], self.temporal_inv_freq) + half_emb_h = torch.outer(self.seq[:H], self.spatial_inv_freq) + half_emb_w = torch.outer(self.seq[:W], self.spatial_inv_freq) + emb = torch.cat( + [ + repeat(half_emb_t, "t d -> t h w d", h=H, w=W), + repeat(half_emb_h, "h d -> t h w d", t=T, w=W), + repeat(half_emb_w, "w d -> t h w d", t=T, h=H), + ] + * 2, + dim=-1, + ) + emb = rearrange(emb, "t h w d -> (t h w) 1 1 d").float() + else: + raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}") + return emb + + def get_scale_factors(self, inv_freq: torch.Tensor, original_seq_len: int) -> torch.Tensor: + """Get the scale factors for YaRN.""" + # Calculate the high and low frequency cutoffs for YaRN. Note: `beta_fast` and `beta_slow` are called + # `high_freq_factor` and `low_freq_factor` in the Llama 3.1 RoPE scaling code. 
+ high_freq_cutoff = 2 * math.pi * self.beta_fast / original_seq_len + low_freq_cutoff = 2 * math.pi * self.beta_slow / original_seq_len + # Obtain a smooth mask that has a value of 0 for low frequencies and 1 for high frequencies, with linear + # interpolation in between. + smooth_mask = torch.clamp((inv_freq - low_freq_cutoff) / (high_freq_cutoff - low_freq_cutoff), min=0, max=1) + # For low frequencies, we scale the frequency by 1/self.scale. For high frequencies, we keep the frequency. + scale_factors = (1 - smooth_mask) / self.scale + smooth_mask + return scale_factors + + def get_inv_freq(self, device: torch.device) -> None: + """Get the inverse frequency.""" + if self.rope_dim == "1D": + assert self.max_position_embeddings is not None, "Max position embeddings required." + inv_freq = 1.0 / ( + self.rope_theta ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim) + ) + if self.apply_yarn: + assert self.original_max_position_embeddings is not None, "Original max position embeddings required." + assert self.beta_slow is not None, "Beta slow value required." + assert self.beta_fast is not None, "Beta fast value required." + + scale_factors = self.get_scale_factors(inv_freq, self.original_max_position_embeddings) + # Apply the scaling factors to inv_freq. + inv_freq = inv_freq * scale_factors + # Set the magnitude scaling factor. + self.mscale = float(self.get_mscale(self.scale) * self.attn_factor) + self.max_seq_len_cached = self.max_position_embeddings + self.inv_freq = inv_freq + + elif self.rope_dim == "2D": + assert self.latent_shape is not None, "Latent shape required." + dim_h = self.dim // 2 + spatial_inv_freq = 1.0 / ( + self.rope_theta ** torch.arange(0, dim_h, 2, dtype=torch.float32, device=device) / dim_h + ) + if self.apply_yarn: + assert self.original_latent_shape is not None, "Original latent shape required." + assert self.beta_slow is not None, "Beta slow value required." + assert self.beta_fast is not None, "Beta fast value required." + + scale_factors = self.get_scale_factors(spatial_inv_freq, self.original_latent_shape[0]) + spatial_inv_freq = spatial_inv_freq * scale_factors + self.mscale = float(self.get_mscale(self.scale) * self.attn_factor) + self.spatial_inv_freq = spatial_inv_freq + self.max_seq_len_cached = max(self.latent_shape) + + elif self.rope_dim == "3D": + assert self.latent_shape is not None, "Latent shape required." + dim_h = self.dim // 6 * 2 + dim_t = self.dim - 2 * dim_h + self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(device) / dim_h + spatial_inv_freq = 1.0 / (self.rope_theta**self.dim_spatial_range) + self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(device) / dim_t + temporal_inv_freq = 1.0 / (self.rope_theta**self.dim_temporal_range) + if self.apply_yarn: + assert self.original_latent_shape is not None, "Original latent shape required." + assert self.beta_slow is not None, "Beta slow value required." + assert self.beta_fast is not None, "Beta fast value required." 
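# --- Illustrative aside (not part of this patch) -----------------------------------
# A quick numeric check of the 3D channel split used above: each spatial axis (H and
# W) gets dim_h = dim // 6 * 2 channels and the temporal axis gets the remainder, so
# the per-axis embeddings concatenate back to exactly `dim` channels. The helper name
# is hypothetical; it only restates the arithmetic.
def split_3d_rope_dims(dim: int) -> tuple:
    dim_h = dim // 6 * 2     # channels for the H axis (W gets the same)
    dim_t = dim - 2 * dim_h  # remaining channels for the T axis
    return dim_t, dim_h, dim_h


for dim in (96, 128, 256):
    dim_t, dim_h, dim_w = split_3d_rope_dims(dim)
    assert dim_t + dim_h + dim_w == dim
    print(dim, "->", (dim_t, dim_h, dim_w))
# 96 -> (32, 32, 32)
# 128 -> (44, 42, 42)
# 256 -> (88, 84, 84)
# ------------------------------------------------------------------------------------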
+ scale_factors_spatial = self.get_scale_factors(spatial_inv_freq, self.original_latent_shape[1]) + spatial_inv_freq = spatial_inv_freq * scale_factors_spatial + scale_factors_temporal = self.get_scale_factors(temporal_inv_freq, self.original_latent_shape[0]) + temporal_inv_freq = temporal_inv_freq * scale_factors_temporal + self.mscale = float(self.get_mscale(self.scale) * self.attn_factor) + self.spatial_inv_freq = spatial_inv_freq + self.temporal_inv_freq = temporal_inv_freq + self.max_seq_len_cached = max(self.latent_shape) + else: + raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}") + + self.freqs = self.compute_freqs() + + +class RotaryPositionEmbeddingPytorchV2(RotaryPositionEmbedding): + """ + Rotary Position Embedding that works in the same way as the TransformerEngine RoPE + (https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) + + """ + + def __init__( + self, + seq_len: int, + training_type: str = None, + **kwargs, + ): + super().__init__( + **kwargs, + ) + emb = self.create_rope_freqs(seq_len=seq_len, training_type=training_type) + emb = emb.transpose(0, 1).contiguous() # [seq, 1, 1, dim] -> [1, seq, 1, dim] + assert emb.shape[0] == 1 and emb.shape[2] == 1, f"emb shape: {emb.shape}" + # cos/sin first then dtype conversion for better precision + self.register_buffer("cos_cached", torch.cos(emb), persistent=False) + self.register_buffer("sin_cached", torch.sin(emb), persistent=False) + + def create_rope_freqs(self, seq_len: int, training_type: str = None) -> torch.Tensor: + """ + Create rotary position embedding frequencies. + + Args: + seq_len (int): Sequence length of a sample. + + Returns: + torch.Tensor: The computed positional embeddings. + """ + if self.rope_dim == "1D": + freqs = super().forward(seq_len=seq_len) + emb = torch.cat((freqs, freqs), dim=-1) + emb = emb.reshape(emb.size(0), 1, 1, emb.size(1)) + + elif self.rope_dim in ["2D", "3D"]: + emb = super().forward(seq_len=seq_len) + if training_type == "text_to_video": + # since we added token at the beginning of the video for text2world, we also extend the position embedding by one token in the beginning + bov_pe = torch.zeros((1, *emb.shape[1:]), device=emb.device) + emb = torch.cat((bov_pe, emb), dim=0) + else: + raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}") + if self.pad_to_multiple_of is not None and emb.shape[0] % self.pad_to_multiple_of != 0: + # Round up to the nearest multiple of pad_to_multiple_of + pad_len = self.pad_to_multiple_of - emb.shape[0] % self.pad_to_multiple_of + emb = torch.cat((emb, torch.zeros((pad_len, *emb.shape[1:]), device=emb.device)), dim=0) + + return emb + + def forward( + self, q: torch.Tensor, k: torch.Tensor, input_pos: Optional[torch.Tensor] = None, seq_len: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + if q.dtype != self.cos_cached.dtype: + self.cos_cached = self.cos_cached.to(q.dtype) + self.sin_cached = self.sin_cached.to(q.dtype) + + cos_emb = self.cos_cached + sin_emb = self.sin_cached + if input_pos is not None: + cos_emb = cos_emb[:, input_pos, :, :] + sin_emb = sin_emb[:, input_pos, :, :] + elif seq_len is not None: + cos_emb = cos_emb[:, :seq_len, :, :] + sin_emb = sin_emb[:, :seq_len, :, :] + q = _apply_rotary_pos_emb_te(q, cos_emb, sin_emb) + k = _apply_rotary_pos_emb_te(k, cos_emb, sin_emb) + return q, k + + +class RotaryPositionEmbeddingPytorchV1(RotaryPositionEmbedding): + """ + Rotary Position Embedding that works in the same way as + mistral_inference 
(https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/rope.py) + or llama3 (https://github.com/meta-llama/llama3/blob/main/llama/model.py) + + """ + + def __init__( + self, + **kwargs, + ): + super().__init__( + **kwargs, + ) + if self.rope_dim == "1D": + emb = torch.stack((self.freqs, self.freqs), dim=-1).reshape(*self.freqs.shape[:-1], -1) + elif self.rope_dim in ["2D", "3D"]: + emb = rearrange(self.freqs, "s 1 1 d -> s d").float() + self.register_buffer("cos_cached", (emb.cos() * self.mscale)[None, :, None, :], persistent=False) + self.register_buffer("sin_cached", (emb.sin() * self.mscale)[None, :, None, :], persistent=False) + + def rotate_half(self, x: torch.Tensor) -> torch.Tensor: + """Rotate half the hidden dimensions of the input tensor.""" + x_reshaped = x.reshape(*x.shape[:-1], -1, 2) + x1 = x_reshaped[..., 0] + x2 = x_reshaped[..., 1] + output = torch.stack((-x2, x1), dim=-1).reshape(*x.shape) + return output + + def forward( + self, q: torch.Tensor, k: torch.Tensor, input_pos: Optional[torch.Tensor] = None, seq_len: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the rotary position embedding. + + Args: + q (torch.Tensor): Query tensor. + k (torch.Tensor): Key tensor. + input_pos (Optional[torch.Tensor]): Starting position for the sequence. + seq_len (Optional[int]): Length of the sequence. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors. + """ + if self.apply_yarn and seq_len > self.max_seq_len_cached: + freqs = super().forward(seq_len) + if self.rope_dim == "1D": + emb = torch.stack((freqs, freqs), dim=-1).reshape(*freqs.shape[:-1], -1) + elif self.rope_dim in ["2D", "3D"]: + emb = rearrange(freqs, "s 1 1 d -> s d").float() + else: + raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}") + self.register_buffer( + "cos_cached", (emb.cos() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False + ) + self.register_buffer( + "sin_cached", (emb.sin() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False + ) + + if input_pos is not None: + cos_cached = self.cos_cached[:, input_pos] + sin_cached = self.sin_cached[:, input_pos] + else: + assert ( + self.cos_cached.shape[1] >= seq_len + ), f"Invalid sequence length; cos_cached.shape {self.cos_cached.shape}, seq_len {seq_len}." + cos_cached = self.cos_cached[:, :seq_len, ...] + sin_cached = self.sin_cached[:, :seq_len, ...] + xq = q * cos_cached + self.rotate_half(q) * sin_cached + xk = k * cos_cached + self.rotate_half(k) * sin_cached + + return xq.type_as(q), xk.type_as(k) + + +class SinCosPosEmbAxisTE(torch.nn.Module): + def __init__( + self, + dim: int, + latent_shape: Optional[List[int]] = None, + pad_to_multiple_of: Optional[int] = None, + dtype: torch.dtype = torch.bfloat16, + **kwargs, + ): + """ + Args: + dim (int): Dimensionality of the input tensor. + latent_shape (Optional[List[int]]): Shape of the latent tensor for video or image inputs. + pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value. + dtype (torch.dtype): Data type of the position embedding tensor. 
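# --- Illustrative aside (not part of this patch) -----------------------------------
# RotaryPositionEmbeddingPytorchV1.rotate_half above rotates interleaved channel
# pairs, while the TransformerEngine-style _rotate_half_te used by V2 rotates the two
# halves of the channel axis. The two conventions produce different layouts, so the
# cos/sin caches and any checkpoint must match the variant in use. A small comparison:
import torch


def rotate_half_pairs(x: torch.Tensor) -> torch.Tensor:
    # V1 convention: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    xr = x.reshape(*x.shape[:-1], -1, 2)
    x1, x2 = xr[..., 0], xr[..., 1]
    return torch.stack((-x2, x1), dim=-1).reshape(*x.shape)


def rotate_half_split(x: torch.Tensor) -> torch.Tensor:
    # TE/V2 convention: (first_half, second_half) -> (-second_half, first_half)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


x = torch.arange(8.0)
print(rotate_half_pairs(x))  # tensor([-1.,  0., -3.,  2., -5.,  4., -7.,  6.])
print(rotate_half_split(x))  # tensor([-4., -5., -6., -7.,  0.,  1.,  2.,  3.])
# ------------------------------------------------------------------------------------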
+ """ + super().__init__() + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + self.latent_shape = latent_shape + T, H, W = latent_shape + emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, pos=np.arange(H)) + emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, pos=np.arange(W)) + emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, pos=np.arange(T)) + + self.register_buffer("pos_emb_h", torch.from_numpy(emb_h).to(dtype=dtype, device="cuda"), persistent=False) + self.register_buffer("pos_emb_w", torch.from_numpy(emb_w).to(dtype=dtype, device="cuda"), persistent=False) + self.register_buffer("pos_emb_t", torch.from_numpy(emb_t).to(dtype=dtype, device="cuda"), persistent=False) + self.pad_to_multiple_of = pad_to_multiple_of + + def forward( + self, + training_type: str = None, + ) -> torch.Tensor: + T, H, W = self.latent_shape + emb = torch.cat( + [ + repeat(self.pos_emb_t, "t d-> t h w d", h=H, w=W), + repeat(self.pos_emb_h, "h d-> t h w d", t=T, w=W), + repeat(self.pos_emb_w, "w d-> t h w d", t=T, h=H), + ], + dim=-1, + ) + # Flatten the T,H,W dimensions + emb = rearrange(emb, "t h w d -> (t h w) d") + + if training_type == "text_to_video": + bov_pe = torch.zeros((1, *emb.shape[1:]), device=emb.device, dtype=emb.dtype) + emb = torch.cat((bov_pe, emb), dim=0) + if self.pad_to_multiple_of is not None and emb.shape[0] % self.pad_to_multiple_of != 0: + pad_len = self.pad_to_multiple_of - emb.shape[0] % self.pad_to_multiple_of + emb = torch.cat((emb, torch.zeros((pad_len, *emb.shape[1:]), device=emb.device, dtype=emb.dtype)), dim=0) + seq_len, dim = emb.shape + emb = emb.reshape(1, seq_len, dim) + return emb diff --git a/ar_modules_mlp.py b/ar_modules_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..45a2ac6c32e8df9e6836ed55973912b8730c0749 --- /dev/null +++ b/ar_modules_mlp.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MLP(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + """ + Initializes the multilayer perceptron (MLP) module. + + Args: + dim: The input and output dimensionality. + hidden_dim: The dimensionality of the hidden layer. + """ + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs the forward pass of the MLP module. + + Args: + x: The input tensor of shape (batch_size, dim). + + Returns: + The output tensor of shape (batch_size, dim). 
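# --- Illustrative aside (not part of this patch) -----------------------------------
# The MLP above is the gated SwiGLU variant used in Llama-style transformers: w1 acts
# as the gate, w3 as the up projection, and w2 projects back down, i.e.
# output = w2(silu(w1(x)) * w3(x)). A shape walk-through with toy sizes:
import torch
import torch.nn as nn
import torch.nn.functional as F

dim, hidden_dim = 64, 256
w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
w3 = nn.Linear(dim, hidden_dim, bias=False)  # up projection
w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection

x = torch.randn(2, 10, dim)
y = w2(F.silu(w1(x)) * w3(x))
print(y.shape)  # torch.Size([2, 10, 64])

# Three weight matrices of size dim x hidden_dim, versus two for a plain two-layer MLP
# with the same hidden width.
print(sum(p.numel() for p in (w1.weight, w2.weight, w3.weight)))  # 49152
# ------------------------------------------------------------------------------------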
+ """ + output = self.w2(F.silu(self.w1(x)) * self.w3(x)) + return output diff --git a/ar_modules_normalization.py b/ar_modules_normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..61f10fe07227a01d582e17f89a9b5089aa506006 --- /dev/null +++ b/ar_modules_normalization.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + + +def create_norm(norm_type: str, dim: int, eps: float = 1e-6): + """ + Creates the specified normalization layer based on the norm_type. + Adopted from TorchTriton: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py + + Args: + norm_type (str): The type of normalization layer to create. + Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm + dim (int): The dimension of the normalization layer. + eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. + + Returns: + The created normalization layer. + + Raises: + NotImplementedError: If an unknown norm_type is provided. + """ + norm_type = norm_type.lower() # Normalize to lowercase + + if norm_type == "layernorm": + return nn.LayerNorm(dim, eps=eps, bias=False) + elif norm_type == "np_layernorm": + return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False) + elif norm_type == "rmsnorm": + return RMSNorm(dim, eps=eps, compile=False) + elif norm_type == "compiled_rmsnorm": + return RMSNorm(dim, eps=eps, compile=True) + elif norm_type == "fused_rmsnorm": + raise NotImplementedError("Fused RMSNorm is not supported yet.") + else: + raise NotImplementedError(f"Unknown norm_type: '{norm_type}'") + + +class RMSNorm(nn.Module): + """ + Initialize the RMSNorm normalization layer. + Reference implementation: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + compile (bool, optional): Whether to compile the forward function. Default is False. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + + def __init__(self, dim: int, eps: float = 1e-6, compile: bool = False): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + self.rmsnorm_fn = torch.compile(self.compute_rmsnorm, fullgraph=True) if compile else self.compute_rmsnorm + + @staticmethod + def compute_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float): + def _norm(x, eps): + # Computes the root-mean-square norm of the input tensor. 
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + output = _norm(x.float(), eps).type_as(x) + return output * weight + + def forward(self, x: torch.Tensor): + return self.rmsnorm_fn(x, self.weight, self.eps) + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) diff --git a/ar_networks.py b/ar_networks.py new file mode 100644 index 0000000000000000000000000000000000000000..29be4d33e5dfb6255b5db0b99bcbc4311a3faa82 --- /dev/null +++ b/ar_networks.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple + +import torch +from torch import nn + +from .ar_tokenizer_modules import CausalConv3d, DecoderFactorized, EncoderFactorized +from .ar_tokenizer_quantizers import FSQuantizer +from .log import log + +NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"]) + + +class CausalDiscreteVideoTokenizer(nn.Module): + def __init__(self, z_channels: int, z_factor: int, embedding_dim: int, **kwargs) -> None: + super().__init__() + self.name = kwargs.get("name", "CausalDiscreteVideoTokenizer") + self.embedding_dim = embedding_dim + self.encoder = EncoderFactorized(z_channels=z_factor * z_channels, **kwargs) + self.decoder = DecoderFactorized(z_channels=z_channels, **kwargs) + + self.quant_conv = CausalConv3d(z_factor * z_channels, embedding_dim, kernel_size=1, padding=0) + self.post_quant_conv = CausalConv3d(embedding_dim, z_channels, kernel_size=1, padding=0) + + self.quantizer = FSQuantizer(**kwargs) + + num_parameters = sum(param.numel() for param in self.parameters()) + log.debug(f"model={self.name}, num_parameters={num_parameters:,}") + log.debug(f"z_channels={z_channels}, embedding_dim={self.embedding_dim}.") + + def to(self, *args, **kwargs): + setattr(self.quantizer, "dtype", kwargs.get("dtype", torch.bfloat16)) + return super(CausalDiscreteVideoTokenizer, self).to(*args, **kwargs) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return self.quantizer(h) + + def decode(self, quant): + quant = self.post_quant_conv(quant) + return self.decoder(quant) + + def forward(self, input): + quant_info, quant_codes, quant_loss = self.encode(input) + reconstructions = self.decode(quant_codes) + if self.training: + return dict(reconstructions=reconstructions, quant_loss=quant_loss, quant_info=quant_info) + return NetworkEval(reconstructions=reconstructions, quant_loss=quant_loss, quant_info=quant_info) diff --git a/ar_tokenizer.py b/ar_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ab8c3c11cf728448ba0a4149373dc877cd1d77a0 --- /dev/null +++ b/ar_tokenizer.py @@ -0,0 +1,322 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from typing import Optional + +import torch +from einops import rearrange + +from .ar_config_tokenizer import TokenizerConfig +from .lazy_config_init import instantiate as lazy_instantiate + + +def update_vocab_size( + existing_vocab_size, + to_be_added_vocab_size, + training_type, + add_special_tokens, + video_special_tokens={}, +): + # New vocab size + if add_special_tokens: + existing_vocab_size += to_be_added_vocab_size + len(video_special_tokens) + # For text_to_video, we add one special token at the beginning of the video + elif training_type == "text_to_video": + existing_vocab_size += to_be_added_vocab_size + 1 + else: + existing_vocab_size += to_be_added_vocab_size + return existing_vocab_size + + +class DiscreteMultimodalTokenizer: + def __init__(self, tokenizer_config: TokenizerConfig): + self.tokenizer_config = tokenizer_config + self.vocab_size = 0 + self.total_seq_len = tokenizer_config.seq_len + self.pad_to_multiple_of = tokenizer_config.pad_to_multiple_of + self.training_type = tokenizer_config.training_type + assert self.training_type in [ + "text_only", + "text_to_video", + "video_to_video", + "image_text_interleaved", + ], f"{self.training_type} not supported" + + self._build_text_tokenizer() + self._build_video_tokenizer() + + def _build_text_tokenizer(self): + r"""Function to initialize the text tokenizer model.""" + if self.tokenizer_config.text_tokenizer is not None: + self.text_tokenizer = lazy_instantiate(self.tokenizer_config.text_tokenizer.config) + self.vocab_size += self.tokenizer_config.text_tokenizer.vocab_size + else: + self.text_tokenizer = None + + def _build_video_tokenizer(self): + r"""Function to initialize the video tokenizer model.""" + if self.tokenizer_config.video_tokenizer is not None: + self.video_tokenizer = lazy_instantiate(self.tokenizer_config.video_tokenizer.config) + self.video_tokenizer = self.video_tokenizer.to("cuda") + self.video_vocab_size = self.tokenizer_config.video_tokenizer.vocab_size + special_token_offset = ( + self.tokenizer_config.video_tokenizer.tokenizer_offset + + self.tokenizer_config.video_tokenizer.vocab_size + ) + self.video_special_tokens = { + "<|begin_of_video|>": special_token_offset, + "<|end_of_video|>": special_token_offset + 1, + "<|pad_token_video|>": special_token_offset + 2, + } + + self.vocab_size = update_vocab_size( + existing_vocab_size=self.vocab_size, + to_be_added_vocab_size=self.tokenizer_config.video_tokenizer.vocab_size, + training_type=self.training_type, + add_special_tokens=self.tokenizer_config.add_special_tokens, + video_special_tokens=self.video_special_tokens, + ) + else: + self.video_tokenizer = None + + @property + def pad_id(self): + r"""Returns the pad_id.""" + + if self.training_type == "text_only" or self.training_type == "image_text_interleaved": + pad_id = self.text_tokenizer.pad_id + elif self.training_type in ["text_to_video", "video_to_video"]: + pad_id = 
self.video_special_tokens["<|pad_token_video|>"] + else: + raise ValueError(f"training_type {self.training_type} not defined") + return pad_id + + @property + def ignore_index(self): + r"""Returns which token should be ignored during loss computation.""" + if self.training_type == "text_only" or self.training_type == "image_text_interleaved": + if self.text_tokenizer.pad_id == self.text_tokenizer.eos_id: + # If the PAD token is the same as the EOS token, we do not ignore it during loss + # computation, since we want the model to be able to predict EOS tokens in inference. + # The PyTorch default ignore_index for the cross-entropy loss is -100. + ignore_index = -100 + else: + ignore_index = self.text_tokenizer.pad_id + elif self.training_type in ["text_to_video", "video_to_video"]: + ignore_index = self.pad_id + else: + raise ValueError(f"training_type {self.training_type} not defined") + return ignore_index + + @property + def stop_tokens(self): + r"""Returns the stop tokens.""" + if self.training_type == "text_only" or self.training_type == "image_text_interleaved": + stop_tokens = self.text_tokenizer.stop_tokens + elif self.training_type in ["text_to_video", "video_to_video"]: + stop_tokens = set([self.video_special_tokens["<|end_of_video|>"]]) + else: + raise ValueError(f"training_type {self.training_type} not defined") + return stop_tokens + + def _tokenize_text(self, raw_text: list[str], max_text_seq_len: int = -1): + r"""Function to tokenize text. + Args: + raw_text (list[str]): List of input strings + max_text_seq_len (int): Maximum sequence length returned by text tokenizer + Returns: + text_tokens (list[list[int]]): List of text tokens + """ + + batch_size = len(raw_text) + text_tokens = [self.text_tokenizer.encode(raw_text[i], bos=True, eos=True) for i in range(batch_size)] + + # Clipping the text tokens so that the sequence length does not exceed max_text_seq_len + if max_text_seq_len > -1: + for i in range(len(text_tokens)): + if len(text_tokens[i]) > max_text_seq_len: + # Simply clip and add end of seq token + text_tokens[i] = text_tokens[i][0 : max_text_seq_len - 1] + [self.text_tokenizer.eos_id] + return text_tokens + + def _tokenize_class(self, cls_labels: list[str]): + r"""Function to tokenize the class label. + Args: + cls_labels (list[str]): List of class indices + Returns: + class_tokens (list[list[int]]): List of class tokens + """ + + # tokenizer_offset tells what offset should be added to the tokens. + # This is needed for vocab expansion. + class_tokens = [[int(x) + self.tokenizer_config.class_tokenizer.tokenizer_offset] for x in cls_labels] + + return class_tokens + + def _tokenize_video(self, videos: torch.Tensor, pixel_chunk_duration: Optional[int] = None): + r"""Function to tokenize video. + Args: + videos (torch.Tensor): Input video data tensor + pixel_chunk_duration (Optional[float]): Pixel chunk duration. If provided, we pass it to the video tokenizer. + Returns: + video_tokens (list[list[int]]): List of video tokens + """ + + video_tokens = [] + batch_size = videos.shape[0] + + quantized_out, _ = self.video_tokenizer.encode(videos, pixel_chunk_duration=pixel_chunk_duration) + indices = self.video_tokenizer.fsq_quantizer.codes_to_indices(quantized_out.permute(0, 2, 3, 4, 1)) + + # Flatten the indices + indices = rearrange(indices, "B T H W -> B (T H W)") + + # tokenizer_offset tells what offset should be added to the tokens. + # This is needed for vocab expansion. 
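# --- Illustrative aside (not part of this patch) -----------------------------------
# The offset addition just below shifts raw video codebook indices past the text
# vocabulary so the two id ranges cannot collide in the shared embedding table, and
# the video special tokens then sit right after the shifted range (matching the
# special_token_offset computation in _build_video_tokenizer above). The sizes here
# are hypothetical, chosen only to make the arithmetic concrete:
text_vocab_size = 64_000                   # ids [0, 64_000) from the text tokenizer
video_codebook_size = 64_000               # raw FSQ indices in [0, 64_000)
video_tokenizer_offset = text_vocab_size   # e.g. an offset equal to the text vocab size

raw_video_index = 123
global_video_id = raw_video_index + video_tokenizer_offset  # 64123

special_token_offset = video_tokenizer_offset + video_codebook_size
begin_of_video = special_token_offset      # <|begin_of_video|>
end_of_video = special_token_offset + 1    # <|end_of_video|>
pad_token_video = special_token_offset + 2  # <|pad_token_video|>

# With add_special_tokens=True, update_vocab_size() would then yield:
vocab_size = text_vocab_size + video_codebook_size + 3
print(global_video_id, begin_of_video, vocab_size)  # 64123 128000 128003
# ------------------------------------------------------------------------------------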
+ indices += self.tokenizer_config.video_tokenizer.tokenizer_offset + + # Add begin and end of video tokens + bov_token = self.video_special_tokens["<|begin_of_video|>"] + eov_token = self.video_special_tokens["<|end_of_video|>"] + + # Append bov and eov tokens + if self.tokenizer_config.add_special_tokens: + for i in range(batch_size): + video_tokens.append([bov_token] + indices[i].tolist() + [eov_token]) + else: + if self.training_type == "text_to_video": + for i in range(batch_size): + video_tokens.append([bov_token] + indices[i].tolist()) + else: + for i in range(batch_size): + video_tokens.append(indices[i].tolist()) + assert ( + len(video_tokens[-1]) == self.tokenizer_config.video_tokenizer.max_seq_len + ), f"Expected {self.tokenizer_config.video_tokenizer.max_seq_len} tokens, got {len(video_tokens[-1])}; video shape: {videos.shape}" + + return video_tokens + + def tokenize(self, data_batch: dict): + r"""Function to tokenize data_dict. + Args: + data_batch (dict): Input data dict + Returns: + tokens (torch.LongTensor): Token tensor dict + """ + + if ( + self.training_type in ["text_only", "image_text_interleaved"] + and not self.tokenizer_config.text_tokenizer.tokenize_here + ): + # In case of pre-computed tokens, just return the data_batch + return data_batch["tokens"], None + + # Online tokenization + tokens = [] + token_boundaries = defaultdict(list) + + # Obtain maximum sequence length + max_text_seq_len = -1 + max_visual_seq_len = -1 + + if self.training_type in ["text_to_video", "video_to_video"]: + max_visual_seq_len = self.tokenizer_config.video_tokenizer.max_seq_len + + # If max visual sequence length is specified, make sure that text is clipped so that + # the full video/image is always seen. + if max_visual_seq_len > -1: + if self.tokenizer_config.add_special_tokens: + max_visual_seq_len = max_visual_seq_len + 2 # Two special tokens is for [bov, eov] or [boi, eoi] token + elif self.training_type == "text_to_video": + max_visual_seq_len = max_visual_seq_len + 1 + else: + max_visual_seq_len = max_visual_seq_len + assert ( + max_visual_seq_len <= self.total_seq_len + ), f"max_visual_seq_len ({max_visual_seq_len}) is greater that total sequence length ({self.total_seq_len})" + max_text_seq_len = self.total_seq_len - max_visual_seq_len + + # Tokenize the text + if ( + "text" in self.training_type + and self.text_tokenizer is not None + and self.tokenizer_config.text_tokenizer.tokenize_here + ): + key = self.tokenizer_config.text_tokenizer.data_key + batch_size = len(data_batch[key]) + assert key in data_batch, f"Key {key} should be present in data for text tokenizer" + tokens = self._tokenize_text(data_batch["caption"], max_text_seq_len) + + for i in range(batch_size): + token_boundaries["text"].append((0, len(tokens[i]))) + else: + tokens = [] + batch_size = None + + # Tokenize the class label + if "class" in self.training_type and self.tokenizer_config.class_tokenizer is not None: + key = self.tokenizer_config.class_tokenizer.data_key + assert key in data_batch, f"Key {key} should be present in data for class tokenizer" + batch_size = len(data_batch[key]) if batch_size is None else batch_size + tokens_class = self._tokenize_class(data_batch[key]) + if len(tokens) == 0: + tokens = tokens_class + for i in range(batch_size): + token_boundaries["class"].append((0, len(tokens[i]))) + else: + for i in range(batch_size): + token_boundaries["class"].append((len(tokens[i]), len(tokens[i]) + len(tokens_class[i]))) + tokens[i] = tokens[i] + tokens_class[i] + + # Tokenize the video + if 
self.video_tokenizer is not None and self.tokenizer_config.video_tokenizer.tokenize_here: + key = self.tokenizer_config.video_tokenizer.data_key + assert key in data_batch, f"Key {key} should be present in data for video tokenizer" + batch_size = len(data_batch[key]) if batch_size is None else batch_size + + pixel_chunk_duration = ( + None # If not specified, we assume it's a video dataset and use the default chunk duration + ) + dataset_name = data_batch.get("dataset_name", None) + if dataset_name is not None and dataset_name.startswith("image"): + # If it's an image dataset, we use a pixel chunk duration of 1 + pixel_chunk_duration = 1 + tokens_video = self._tokenize_video(data_batch[key], pixel_chunk_duration=pixel_chunk_duration) + if len(tokens) == 0: + tokens = tokens_video + for i in range(batch_size): + token_boundaries["video"].append((0, len(tokens[i]))) + # [B,] each entry is ((0, len(tokens[i]))) + else: + for i in range(batch_size): + token_boundaries["video"].append((len(tokens[i]), len(tokens[i]) + len(tokens_video[i]))) + tokens[i] = tokens[i] + tokens_video[i] + + # Combine the tokens and do padding + max_seq_len_in_batch = max([len(token) for token in tokens]) + if self.pad_to_multiple_of is not None: + # Pad the sequence length to the nearest multiple of pad_to_multiple_of + max_seq_len_in_batch = ((max_seq_len_in_batch - 1) // self.pad_to_multiple_of + 1) * self.pad_to_multiple_of + pad_to_len = min(max_seq_len_in_batch, self.total_seq_len) + for i in range(len(tokens)): + if len(tokens[i]) < pad_to_len: + tokens[i] = tokens[i] + [self.pad_id] * (pad_to_len - len(tokens[i])) + else: + tokens[i] = tokens[i][0:pad_to_len] + + # Convert it to long tensor + tokens = torch.LongTensor(tokens) + return tokens, token_boundaries diff --git a/ar_tokenizer_image_text_tokenizer.py b/ar_tokenizer_image_text_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5877aa166d1d946b98ce604e2bd1a4284b884ae6 --- /dev/null +++ b/ar_tokenizer_image_text_tokenizer.py @@ -0,0 +1,318 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
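# --- Illustrative aside (not part of this patch) -----------------------------------
# The tail of DiscreteMultimodalTokenizer.tokenize() above rounds the longest sequence
# in the batch up to a multiple of pad_to_multiple_of, caps it at total_seq_len, and
# then pads (or clips) every sequence to that length with pad_id. A standalone sketch
# of that bookkeeping (pad_batch is a hypothetical helper):
import torch


def pad_batch(token_lists, pad_id, pad_to_multiple_of, total_seq_len):
    max_len = max(len(tokens) for tokens in token_lists)
    if pad_to_multiple_of is not None:
        max_len = ((max_len - 1) // pad_to_multiple_of + 1) * pad_to_multiple_of
    pad_to = min(max_len, total_seq_len)
    padded = [tokens[:pad_to] + [pad_id] * max(0, pad_to - len(tokens)) for tokens in token_lists]
    return torch.LongTensor(padded)


batch = [[5, 6, 7], [1, 2, 3, 4, 5, 6, 7, 8, 9]]
print(pad_batch(batch, pad_id=0, pad_to_multiple_of=8, total_seq_len=64).shape)
# torch.Size([2, 16])  (longest sequence is 9, rounded up to 16)
# ------------------------------------------------------------------------------------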
+ +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +import transformers +from transformers import AutoImageProcessor +from transformers.image_utils import ImageInput, is_valid_image, load_image + +from .ar_tokenizer_text_tokenizer import TextTokenizer +from .log import log + +# Configuration for different vision-language models +IMAGE_CONFIGS = { + "pixtral": { + "patch_size": 16, + "image_token": "[IMG]", + "image_break_token": "[IMG_BREAK]", + "image_end_token": "[IMG_END]", + } +} + +# Chat template for Pixtral-12B-Instruct +PIXTRAL_CHAT_TEMPLATE = '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}\n {%- endif %}\n {%- if message["role"] == "user" %}\n {%- if loop.last and system_message is defined %}\n {{- "[INST]" + system_message + "\n\n" }}\n {%- else %}\n {{- "[INST]" }}\n {%- endif %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["content"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message["content"] }}\n {%- endif %}\n {{- "[/INST]" }}\n {%- elif message["role"] == "assistant" %}\n {{- message["content"] + eos_token}}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}' + + +# Copied from transformers.models.pixtral.processing_pixtral.is_url +def is_url(val) -> bool: + """Check if the given value is a URL.""" + return isinstance(val, str) and val.startswith("http") + + +# Copied from transformers.models.pixtral.processing_pixtral.is_image_or_image_url +def is_image_or_image_url(elem): + """Check if the given element is an image or an image URL.""" + return is_url(elem) or is_valid_image(elem) + + +def load_image_list( + image_list: List[Union[str, "PIL.Image.Image"]], timeout: Optional[float] = None +) -> List["PIL.Image.Image"]: + """ + Load a list of images. + + Args: + image_list (List[Union[str, PIL.Image.Image]]): The list of images to load. + timeout (Optional[float]): The timeout for loading the image. + + Returns: + List[PIL.Image.Image]: The list of loaded images. + """ + return [load_image(image, timeout=timeout) for image in image_list] + + +class ImageTextTokenizer(TextTokenizer): + """ + Image-text tokenizer class that extends the text tokenizer to support vision tokens as well. + """ + + def __init__( + self, + model_family: str, + is_instruct_model: bool, + tokenizer_path: str, + image_processor_path: str, + ): + """ + Initialize the ImageTextTokenizer. + + Args: + model_family (str): The model family. + is_instruct_model (bool): Whether the model is an instruct model. + s3_credential_path (str): The path to the s3 credential file. Defaults to "credentials/pbss_dir.secret". + + Raises: + AssertionError: If the model family is not supported or if the transformers version is incompatible. 
+ """ + super().__init__( + model_family=model_family, + is_instruct_model=is_instruct_model, + local_path=tokenizer_path, + ) + assert model_family in ["pixtral"], f"Unsupported model family: {model_family}" + if model_family == "pixtral": + # Need transformers>=4.45.0 + assert transformers.__version__ >= "4.45.0", "Pixtral requires transformers>=4.45.0" + assert is_instruct_model, "Pixtral requires is_instruct_model=True" + if not hasattr(self.tokenizer, "chat_template") or self.tokenizer.chat_template is None: + setattr(self.tokenizer, "chat_template", PIXTRAL_CHAT_TEMPLATE) + log.debug(f"Pixtral tokenizer chat template set to: {PIXTRAL_CHAT_TEMPLATE}") + + # Set up image-specific configurations + image_config = IMAGE_CONFIGS[model_family] + self.patch_size = image_config["patch_size"] + self.image_token = image_config["image_token"] + self.image_break_token = image_config["image_break_token"] + self.image_end_token = image_config["image_end_token"] + + # Initialize the image processor + self.image_processor = AutoImageProcessor.from_pretrained(image_processor_path) + + def encode( + self, + text: Union[str, List[str], List[int]], + *, # Enforce keyword-only arguments + images: Optional[ImageInput] = None, + image_kwargs: Optional[Dict[str, Any]] = None, + **text_kwargs, + ) -> List[int]: + """ + Process the images and return the tokenized images and text. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. + image_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for image processing. + **text_kwargs: Additional keyword arguments for text processing. + + Returns: + A dictionary with the following fields: + - **input_ids** -- List of token ids to be fed to a model. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + - **pixel_values** -- Pixel values to be fed to a model. + + Raises: + ValueError: If the input images are in an invalid format. + """ + + output_dict, image_inputs = {}, {} + if images is not None: + # Preprocess images + if is_image_or_image_url(images): + images = [[images]] + elif isinstance(images, list) and is_image_or_image_url(images[0]): + images = [images] + elif ( + not isinstance(images, list) + and not isinstance(images[0], list) + and not is_image_or_image_url(images[0][0]) + ): + raise ValueError( + "Invalid input images. Please provide a single image or a list of images or a list of list of images." 
+ ) + + # Load and process images + images = [load_image_list(sample) for sample in images] + image_kwargs = image_kwargs or {} + image_inputs = self.image_processor(images, patch_size=self.patch_size, return_tensors="np", **image_kwargs) + + # Validate image inputs + assert "pixel_values" in image_inputs, "pixel_values not found in image_inputs" + assert "image_sizes" in image_inputs, "image_sizes not found in image_inputs" + assert len(image_inputs.keys()) == 2, "Only one key is allowed in image_inputs, got {}".format( + image_inputs.keys() + ) + + # Extract pixel values and image sizes + pixel_values = image_inputs["pixel_values"][0] + image_sizes = image_inputs["image_sizes"][0] + unique_sizes = np.unique(image_sizes, axis=0) + + assert len(unique_sizes) == 1, "All images must have the same size, got {}".format(unique_sizes) + + # Convert pixel values to PyTorch tensor + pixel_values = np.asarray(pixel_values) + pixel_values = torch.from_numpy(pixel_values) + output_dict["pixel_values"] = pixel_values + output_dict["image_sizes"] = image_sizes + + # Expand image tokens in text + if image_inputs.get("pixel_values") is not None: + replace_strings = [] + # Calculate the number of tokens needed for each image and create a placeholder + for image_size in image_sizes: + height, width = image_size + num_height_tokens = height // self.patch_size + num_width_tokens = width // self.patch_size + replace_tokens = [[self.image_token] * num_width_tokens + [self.image_break_token]] * num_height_tokens + # Flatten list + replace_tokens = [item for sublist in replace_tokens for item in sublist] + replace_tokens[-1] = self.image_end_token + replace_str = "".join(replace_tokens) + replace_strings.append(replace_str) + text = text.replace(self.image_token, "", 1) + + # Replace placeholders with actual image token sequences + while "" in text: + replace_str = replace_strings.pop(0) + text = text.replace("", replace_str, 1) + + # Encode the text + text_inputs = super(ImageTextTokenizer, self).encode(text, **text_kwargs) + + output_dict["input_ids"] = text_inputs + return output_dict + + def apply_chat_template( + self, + conversation: List[Dict[str, Any]] | List[List[Dict[str, Any]]], + *, + images: Optional[ImageInput] = None, + image_kwargs: Optional[Dict[str, Any]] = None, + add_generation_prompt: bool = False, + tokenize: bool = True, + padding: bool = False, + truncation: bool = False, + max_length: Optional[int] = None, + return_tensors: Optional[str] = None, + return_dict: bool = True, + return_assistant_tokens_mask: bool = False, + generation_prefix: str = "", + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """ + Apply the chat template to the conversation. + + Args: + conversation (List[Dict[str, Any]] | List[List[Dict[str, Any]]]): The conversation to process. + images (Optional[ImageInput]): Images to include in the conversation. + image_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for image processing. + add_generation_prompt (bool): Whether to add a generation prompt. + tokenize (bool): Whether to tokenize the output. + padding (bool): Whether to pad the output. + truncation (bool): Whether to truncate the output. + max_length (Optional[int]): Maximum length of the output. + return_tensors (Optional[str]): The type of tensors to return. + return_dict (bool): Whether to return a dictionary. + return_assistant_tokens_mask (bool): Whether to return the assistant tokens mask. + generation_prefix (str): Prefix to add before asking model to generate. 
Helpful to guide the generation. Defaults to "". + tokenizer_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the tokenizer. + **kwargs: Additional keyword arguments. + + Returns: + The processed conversation with applied chat template. + + Raises: + AssertionError: If return_dict is False or if the conversation format is invalid. + """ + assert return_dict, "return_dict must be True for ImageTextTokenizer" + assert isinstance(conversation, list), "conversation must be a list" + if isinstance(conversation[0], list): + assert len(conversation) == 1, "Only support single-conversation input, got {}".format(conversation) + conversation = conversation[0] + + # Extract images from the conversation if not provided + if images is None: + images = [] + for msg in conversation: + if msg.get("images", None) is not None: + images = images + (msg["images"]) + images = load_image_list(images) + # In case the input does not have images, will ignore + # Useful in feeding VLM inputs with and without images + if isinstance(images, list) and len(images) == 0: + images = None + + # Apply the chat template to the text + text = super().apply_chat_template( + conversation, + tokenize=False, + add_generation_prompt=add_generation_prompt, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + return_dict=False, + return_assistant_tokens_mask=return_assistant_tokens_mask, + generation_prefix=generation_prefix, + tokenizer_kwargs=tokenizer_kwargs, + **kwargs, + ) + + if tokenizer_kwargs is None: + tokenizer_kwargs = {} + + # Encode the text and images + output = self.encode( + text, + images=images, + image_kwargs=image_kwargs, + tokenize=tokenize, + padding=padding, + truncation=truncation, + max_length=max_length, + add_special_tokens=False, + return_tensors=return_tensors, + **tokenizer_kwargs, + ) + return output + + @property + def model_input_names(self): + """ + Get the combined model input names from both the text tokenizer and image processor. + + Returns: + List[str]: A list of unique input names. + """ + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) diff --git a/ar_tokenizer_modules.py b/ar_tokenizer_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..0c2f9c6280ccfa60e1ba8a38e3062e0caf99e71e --- /dev/null +++ b/ar_tokenizer_modules.py @@ -0,0 +1,560 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""The model definition for 3D layers + +Adapted from: https://github.com/lucidrains/magvit2-pytorch/blob/9f49074179c912736e617d61b32be367eb5f993a/ +magvit2_pytorch/magvit2_pytorch.py#L889 + +[MIT License Copyright (c) 2023 Phil Wang] +https://github.com/lucidrains/magvit2-pytorch/blob/9f49074179c912736e617d61b32be367eb5f993a/LICENSE +""" +import math +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .ar_tokenizer_patching import Patcher3D, UnPatcher3D +from .ar_tokenizer_utils import ( + CausalNormalize, + batch2space, + batch2time, + cast_tuple, + is_odd, + nonlinearity, + replication_pad, + space2batch, + time2batch, +) +from .log import log + + +class CausalConv3d(nn.Module): + def __init__( + self, + chan_in: int = 1, + chan_out: int = 1, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + pad_mode: str = "constant", + **kwargs, + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + assert is_odd(height_kernel_size) and is_odd(width_kernel_size) + + dilation = kwargs.pop("dilation", 1) + stride = kwargs.pop("stride", 1) + time_stride = kwargs.pop("time_stride", 1) + time_dilation = kwargs.pop("time_dilation", 1) + padding = kwargs.pop("padding", 1) + + self.pad_mode = pad_mode + time_pad = time_dilation * (time_kernel_size - 1) + (1 - time_stride) + self.time_pad = time_pad + + self.spatial_pad = (padding, padding, padding, padding) + + stride = (time_stride, stride, stride) + dilation = (time_dilation, dilation, dilation) + self.conv3d = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def _replication_pad(self, x: torch.Tensor) -> torch.Tensor: + x_prev = x[:, :, :1, ...].repeat(1, 1, self.time_pad, 1, 1) + x = torch.cat([x_prev, x], dim=2) + padding = self.spatial_pad + (0, 0) + return F.pad(x, padding, mode=self.pad_mode, value=0.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._replication_pad(x) + return self.conv3d(x) + + +class CausalHybridUpsample3d(nn.Module): + def __init__(self, in_channels: int, spatial_up: bool = True, temporal_up: bool = True, **ignore_kwargs) -> None: + super().__init__() + self.conv1 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(3, 1, 1), stride=1, time_stride=1, padding=0) + if temporal_up + else nn.Identity() + ) + self.conv2 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(1, 3, 3), stride=1, time_stride=1, padding=1) + if spatial_up + else nn.Identity() + ) + self.conv3 = ( + CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, time_stride=1, padding=0) + if spatial_up or temporal_up + else nn.Identity() + ) + self.spatial_up = spatial_up + self.temporal_up = temporal_up + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.spatial_up and not self.temporal_up: + return x + + # hybrid upsample temporally. + if self.temporal_up: + time_factor = 1.0 + 1.0 * (x.shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + x = x.repeat_interleave(int(time_factor), dim=2) + x = x[..., int(time_factor - 1) :, :, :] + x = self.conv1(x) + x + + # hybrid upsample spatially. + if self.spatial_up: + x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4) + x = self.conv2(x) + x + + # final 1x1x1 conv. 
+ x = self.conv3(x) + return x + + +class CausalHybridDownsample3d(nn.Module): + def __init__( + self, in_channels: int, spatial_down: bool = True, temporal_down: bool = True, **ignore_kwargs + ) -> None: + super().__init__() + self.conv1 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(1, 3, 3), stride=2, time_stride=1, padding=0) + if spatial_down + else nn.Identity() + ) + self.conv2 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(3, 1, 1), stride=1, time_stride=2, padding=0) + if temporal_down + else nn.Identity() + ) + self.conv3 = ( + CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, time_stride=1, padding=0) + if spatial_down or temporal_down + else nn.Identity() + ) + self.spatial_down = spatial_down + self.temporal_down = temporal_down + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.spatial_down and not self.temporal_down: + return x + + # hybrid downsample spatially. + if self.spatial_down: + pad = (0, 1, 0, 1, 0, 0) + x = F.pad(x, pad, mode="constant", value=0) + x1 = self.conv1(x) + x2 = F.avg_pool3d(x, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + x = x1 + x2 + + # hybrid downsample temporally. + if self.temporal_down: + x = replication_pad(x) + x1 = self.conv2(x) + x2 = F.avg_pool3d(x, kernel_size=(2, 1, 1), stride=(2, 1, 1)) + x = x1 + x2 + + # final 1x1x1 conv. + x = self.conv3(x) + return x + + +class CausalResnetBlockFactorized3d(nn.Module): + def __init__(self, *, in_channels: int, out_channels: int = None, dropout: float, num_groups: int) -> None: + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = CausalNormalize(in_channels, num_groups=1) + self.conv1 = nn.Sequential( + CausalConv3d(in_channels, out_channels, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + self.norm2 = CausalNormalize(out_channels, num_groups=num_groups) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = nn.Sequential( + CausalConv3d(out_channels, out_channels, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + self.nin_shortcut = ( + CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + x = self.nin_shortcut(x) + + return x + h + + +class CausalAttnBlock(nn.Module): + def __init__(self, in_channels: int, num_groups: int) -> None: + super().__init__() + + self.norm = CausalNormalize(in_channels, num_groups=num_groups) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + q, batch_size = time2batch(q) + k, batch_size = time2batch(k) + v, batch_size = time2batch(v) + + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) 
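+        # q is now (b, h*w, c); time was folded into the batch above, so the attention below is
+        # purely spatial and computed independently per frame.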
+ k = k.reshape(b, c, h * w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** (-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = batch2time(h_, batch_size) + h_ = self.proj_out(h_) + return x + h_ + + +class CausalTemporalAttnBlock(nn.Module): + def __init__(self, in_channels: int, num_groups: int) -> None: + super().__init__() + + self.norm = CausalNormalize(in_channels, num_groups=num_groups) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + q, batch_size, height = space2batch(q) + k, _, _ = space2batch(k) + v, _, _ = space2batch(v) + + bhw, c, t = q.shape + q = q.permute(0, 2, 1) # (bhw, t, c) + k = k.permute(0, 2, 1) # (bhw, t, c) + v = v.permute(0, 2, 1) # (bhw, t, c) + + w_ = torch.bmm(q, k.permute(0, 2, 1)) # (bhw, t, t) + w_ = w_ * (int(c) ** (-0.5)) + + # Apply causal mask + mask = torch.tril(torch.ones_like(w_)) + w_ = w_.masked_fill(mask == 0, float("-inf")) + w_ = F.softmax(w_, dim=2) + + # attend to values + h_ = torch.bmm(w_, v) # (bhw, t, c) + h_ = h_.permute(0, 2, 1).reshape(bhw, c, t) # (bhw, c, t) + + h_ = batch2space(h_, batch_size, height) + h_ = self.proj_out(h_) + return x + h_ + + +class EncoderFactorized(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int, + temporal_compression: int, + **ignore_kwargs, + ) -> None: + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # Patcher. 
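+        # The 3D patcher trades spatial/temporal resolution for channels up front
+        # (in_channels * patch_size**3), so it accounts for log2(patch_size) of the requested
+        # compression and the down blocks below only need to supply the remainder.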
+ patch_size = ignore_kwargs.get("patch_size", 1) + self.patcher3d = Patcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + in_channels = in_channels * patch_size * patch_size * patch_size + + # calculate the number of downsample operations + self.num_spatial_downs = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) + assert ( + self.num_spatial_downs <= self.num_resolutions + ), f"Spatially downsample {self.num_resolutions} times at most" + + self.num_temporal_downs = int(math.log2(temporal_compression)) - int(math.log2(patch_size)) + assert ( + self.num_temporal_downs <= self.num_resolutions + ), f"Temporally downsample {self.num_resolutions} times at most" + + # downsampling + self.conv_in = nn.Sequential( + CausalConv3d(in_channels, channels, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(channels, channels, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + curr_res = resolution // patch_size + in_ch_mult = (1,) + tuple(channels_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = channels * in_ch_mult[i_level] + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks): + block.append( + CausalResnetBlockFactorized3d( + in_channels=block_in, out_channels=block_out, dropout=dropout, num_groups=1 + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append( + nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1) + ) + ) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + spatial_down = i_level < self.num_spatial_downs + temporal_down = i_level < self.num_temporal_downs + down.downsample = CausalHybridDownsample3d( + block_in, spatial_down=spatial_down, temporal_down=temporal_down + ) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlockFactorized3d( + in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1 + ) + self.mid.attn_1 = nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1) + ) + self.mid.block_2 = CausalResnetBlockFactorized3d( + in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1 + ) + + # end + self.norm_out = CausalNormalize(block_in, num_groups=1) + self.conv_out = nn.Sequential( + CausalConv3d(block_in, z_channels, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(z_channels, z_channels, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patcher3d(x) + + # downsampling + h = self.conv_in(x) + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](h) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = self.down[i_level].downsample(h) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class DecoderFactorized(nn.Module): + def __init__( + self, + out_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: 
int, + temporal_compression: int, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # UnPatcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.unpatcher3d = UnPatcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + out_ch = out_channels * patch_size * patch_size * patch_size + + # calculate the number of upsample operations + self.num_spatial_ups = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) + assert self.num_spatial_ups <= self.num_resolutions, f"Spatially upsample {self.num_resolutions} times at most" + self.num_temporal_ups = int(math.log2(temporal_compression)) - int(math.log2(patch_size)) + assert ( + self.num_temporal_ups <= self.num_resolutions + ), f"Temporally upsample {self.num_resolutions} times at most" + + block_in = channels * channels_mult[self.num_resolutions - 1] + curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + log.debug("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = nn.Sequential( + CausalConv3d(z_channels, block_in, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(block_in, block_in, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlockFactorized3d( + in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1 + ) + self.mid.attn_1 = nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1) + ) + self.mid.block_2 = CausalResnetBlockFactorized3d( + in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1 + ) + + legacy_mode = ignore_kwargs.get("legacy_mode", False) + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + CausalResnetBlockFactorized3d( + in_channels=block_in, out_channels=block_out, dropout=dropout, num_groups=1 + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append( + nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1) + ) + ) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + # The layer index for temporal/spatial downsampling performed in the encoder should correspond + # to the layer index, inreverse order, where upsampling is performed in the decoder. + # If you've a pre-trained model, you can simply finetune. + # For example: + # Input tensor = (1, 3, 17, 32, 32) + # Patch size = 4 for 3D wavelet transform + # Compression rate = (8x16x16) + # + # We expect successive downsampling in the encoder and upsampling in the decoder to be mirrored. + # ENCODER: `(...,5,8,8) -> (...,3,4,4) -> (...,3,2,2)` + # DECODER: `(...,3,2,2) -> (...,3,4,4) -> (...,5,8,8)` + # + # if legacy_mode is True, the temporal upsampling is not perfectly mirrored. + # ENCODER: `(...,5,8,8) -> (...,3,4,4) -> (...,3,2,2)` + # DECODER: `(...,3,2,2) -> (...,5,4,4) -> (...,5,8,8)` + # + # Most of the CV and DV tokenizers were trained before 09/01/2024 with upsampling that's not mirrored. + # Going forward, new CV/DV tokenizers will adopt `legacy_mode=False`, i.e. use mirrored upsampling. 
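+                # Worked through the example above (patch size 4, 8x16x16 compression):
+                # num_temporal_ups=1 and num_spatial_ups=2, so with legacy_mode=False only
+                # i_level_reverse=1 upsamples in time, while both non-final levels upsample spatially.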
+ i_level_reverse = self.num_resolutions - i_level - 1 + if legacy_mode: + temporal_up = i_level_reverse < self.num_temporal_ups + else: + temporal_up = 0 < i_level_reverse < self.num_temporal_ups + 1 + spatial_up = temporal_up or ( + i_level_reverse < self.num_spatial_ups and self.num_spatial_ups > self.num_temporal_ups + ) + up.upsample = CausalHybridUpsample3d(block_in, spatial_up=spatial_up, temporal_up=temporal_up) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = CausalNormalize(block_in, num_groups=1) + self.conv_out = nn.Sequential( + CausalConv3d(block_in, out_ch, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(out_ch, out_ch, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + def forward(self, z): + h = self.conv_in(z) + + # middle block. + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # decoder blocks. + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = self.unpatcher3d(h) + return h diff --git a/ar_tokenizer_patching.py b/ar_tokenizer_patching.py new file mode 100644 index 0000000000000000000000000000000000000000..cd5b621f9d526cff7966c77225656e9327adde30 --- /dev/null +++ b/ar_tokenizer_patching.py @@ -0,0 +1,279 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The patcher and unpatcher implementation for 2D and 3D data.""" + +import torch +import torch.nn.functional as F +from einops import rearrange + +_WAVELETS = { + "haar": torch.tensor([0.7071067811865476, 0.7071067811865476]), + "rearrange": torch.tensor([1.0, 1.0]), +} +_PERSISTENT = False + + +class Patcher(torch.nn.Module): + """A module to convert image tensors into patches using torch operations. + + The main difference from `class Patching` is that this module implements + all operations using torch, rather than python or numpy, for efficiency purpose. + + It's bit-wise identical to the Patching module outputs, with the added + benefit of being torch.jit scriptable. 
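+
+    Example (illustrative, shapes only, using the Haar patch method):
+        >>> patcher = Patcher(patch_size=2, patch_method="haar")
+        >>> patcher(torch.randn(1, 3, 64, 64)).shape
+        torch.Size([1, 12, 32, 32])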
+ """ + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__() + self.patch_size = patch_size + self.patch_method = patch_method + self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT) + self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item())) + self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=_PERSISTENT) + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + if self.patch_method == "haar": + return self._haar(x) + elif self.patch_method == "rearrange": + return self._arrange(x) + else: + raise ValueError("Unknown patch method: " + self.patch_method) + + def _dwt(self, x, mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + + n = h.shape[0] + g = x.shape[1] + hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype) + xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2)) + xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2)) + xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1)) + xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1)) + xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1)) + xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1)) + + out = torch.cat([xll, xlh, xhl, xhh], dim=1) + if rescale: + out = out / 2 + return out + + def _haar(self, x): + for _ in self.range: + x = self._dwt(x, rescale=True) + return x + + def _arrange(self, x): + x = rearrange(x, "b c (h p1) (w p2) -> b (c p1 p2) h w", p1=self.patch_size, p2=self.patch_size).contiguous() + return x + + +class Patcher3D(Patcher): + """A 3D discrete wavelet transform for video data, expects 5D tensor, i.e. a batch of videos.""" + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__(patch_method=patch_method, patch_size=patch_size) + self.register_buffer( + "patch_size_buffer", patch_size * torch.ones([1], dtype=torch.int32), persistent=_PERSISTENT + ) + + def _dwt(self, x, mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + + n = h.shape[0] + g = x.shape[1] + hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + # Handles temporal axis. + x = F.pad(x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype) + xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + + # Handles spatial axes. 
+ xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + + xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1) + if rescale: + out = out / (2 * torch.sqrt(torch.tensor(2.0))) + return out + + def _haar(self, x): + xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2) + x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2) + for _ in self.range: + x = self._dwt(x, rescale=True) + return x + + def _arrange(self, x): + xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2) + x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2) + x = rearrange( + x, + "b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w", + p1=self.patch_size, + p2=self.patch_size, + p3=self.patch_size, + ).contiguous() + return x + + +class UnPatcher(torch.nn.Module): + """A module to convert patches into image tensorsusing torch operations. + + The main difference from `class Unpatching` is that this module implements + all operations using torch, rather than python or numpy, for efficiency purpose. + + It's bit-wise identical to the Unpatching module outputs, with the added + benefit of being torch.jit scriptable. + """ + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__() + self.patch_size = patch_size + self.patch_method = patch_method + self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT) + self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item())) + self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=_PERSISTENT) + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + if self.patch_method == "haar": + return self._ihaar(x) + elif self.patch_method == "rearrange": + return self._iarrange(x) + else: + raise ValueError("Unknown patch method: " + self.patch_method) + + def _idwt(self, x, rescale=False): + dtype = x.dtype + h = self.wavelets + n = h.shape[0] + + g = x.shape[1] // 4 + hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1]) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1) + + # Inverse transform. 
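+        # Transposed convolutions undo the forward DWT: upsample along height first, then width,
+        # summing the low- and high-frequency contributions at each step.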
+ yl = torch.nn.functional.conv_transpose2d(xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + yl += torch.nn.functional.conv_transpose2d(xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + yh = torch.nn.functional.conv_transpose2d(xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + yh += torch.nn.functional.conv_transpose2d(xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + y = torch.nn.functional.conv_transpose2d(yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)) + y += torch.nn.functional.conv_transpose2d(yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)) + + if rescale: + y = y * 2 + return y + + def _ihaar(self, x): + for _ in self.range: + x = self._idwt(x, rescale=True) + return x + + def _iarrange(self, x): + x = rearrange(x, "b (c p1 p2) h w -> b c (h p1) (w p2)", p1=self.patch_size, p2=self.patch_size) + return x + + +class UnPatcher3D(UnPatcher): + """A 3D inverse discrete wavelet transform for video wavelet decompositions.""" + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__(patch_method=patch_method, patch_size=patch_size) + + def _idwt(self, x, rescale=False): + dtype = x.dtype + h = self.wavelets + + g = x.shape[1] // 8 # split into 8 spatio-temporal filtered tesnors. + hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1]) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hl = hl.to(dtype=dtype) + hh = hh.to(dtype=dtype) + + xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1) + + # Height height transposed convolutions. + xll = F.conv_transpose3d(xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xll += F.conv_transpose3d(xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + xlh = F.conv_transpose3d(xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlh += F.conv_transpose3d(xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + xhl = F.conv_transpose3d(xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhl += F.conv_transpose3d(xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + xhh = F.conv_transpose3d(xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhh += F.conv_transpose3d(xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + # Handles width transposed convolutions. + xl = F.conv_transpose3d(xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xl += F.conv_transpose3d(xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xh = F.conv_transpose3d(xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xh += F.conv_transpose3d(xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + + # Handles time axis transposed convolutions. + x = F.conv_transpose3d(xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + x += F.conv_transpose3d(xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + + if rescale: + x = x * (2 * torch.sqrt(torch.tensor(2.0))) + return x + + def _ihaar(self, x): + for _ in self.range: + x = self._idwt(x, rescale=True) + x = x[:, :, self.patch_size - 1 :, ...] + return x + + def _iarrange(self, x): + x = rearrange( + x, + "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)", + p1=self.patch_size, + p2=self.patch_size, + p3=self.patch_size, + ) + x = x[:, :, self.patch_size - 1 :, ...] 
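+        # The slice above removes the (patch_size - 1) leading frames that the forward patcher
+        # duplicated to keep the temporal axis causal.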
+ return x diff --git a/ar_tokenizer_quantizers.py b/ar_tokenizer_quantizers.py new file mode 100644 index 0000000000000000000000000000000000000000..e07b51aef6f32fb39266c2f12de27c9ff87eb4d7 --- /dev/null +++ b/ar_tokenizer_quantizers.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Quantizers for discrete image and video tokenization.""" + +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange + +from .ar_tokenizer_utils import default, pack_one, round_ste, unpack_one + + +class FSQuantizer(nn.Module): + """Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505 + + Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/ + vector_quantize_pytorch/finite_scalar_quantization.py + [Copyright (c) 2020 Phil Wang] + https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/LICENSE + """ + + def __init__( + self, + levels: list[int], + dim: Optional[int] = None, + num_codebooks=1, + keep_num_codebooks_dim: Optional[bool] = None, + scale: Optional[float] = None, + **ignore_kwargs, + ): + super().__init__() + self.dtype = ignore_kwargs.get("dtype", torch.float32) + _levels = torch.tensor(levels, dtype=torch.int32) + self.register_buffer("_levels", _levels, persistent=False) + + _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32) + self.register_buffer("_basis", _basis, persistent=False) + + self.scale = scale + + codebook_dim = len(levels) + self.codebook_dim = codebook_dim + + effective_codebook_dim = codebook_dim * num_codebooks + self.num_codebooks = num_codebooks + self.effective_codebook_dim = effective_codebook_dim + + keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) + assert not (num_codebooks > 1 and not keep_num_codebooks_dim) + self.keep_num_codebooks_dim = keep_num_codebooks_dim + + self.dim = default(dim, len(_levels) * num_codebooks) + + has_projections = self.dim != effective_codebook_dim + self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity() + self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity() + self.has_projections = has_projections + + self.codebook_size = self._levels.prod().item() + + implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False) + self.register_buffer("implicit_codebook", implicit_codebook, persistent=False) + + def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor: + """Bound `z`, an array of shape (..., d).""" + half_l = (self._levels - 1) * (1 + eps) / 2 + offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) + shift = (offset / half_l).atanh() + return (z + shift).tanh() * half_l - offset + + def quantize(self, z: torch.Tensor) -> 
torch.Tensor: + """Quantizes z, returns quantized zhat, same shape as z.""" + quantized = round_ste(self.bound(z)) + half_width = self._levels // 2 # Renormalize to [-1, 1]. + return quantized / half_width + + def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor: + half_width = self._levels // 2 + return (zhat_normalized * half_width) + half_width + + def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor: + half_width = self._levels // 2 + return (zhat - half_width) / half_width + + def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor: + """Converts a `code` to an index in the codebook.""" + assert zhat.shape[-1] == self.codebook_dim + zhat = self._scale_and_shift(zhat).float() + return (zhat * self._basis).sum(dim=-1).to(torch.int32) + + def indices_to_codes(self, indices: torch.Tensor, project_out=True) -> torch.Tensor: + """Inverse of `codes_to_indices`.""" + is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) + indices = rearrange(indices, "... -> ... 1") + codes_non_centered = (indices // self._basis) % self._levels + codes = self._scale_and_shift_inverse(codes_non_centered) + + if self.keep_num_codebooks_dim: + codes = rearrange(codes, "... c d -> ... (c d)") + + if project_out: + codes = self.project_out(codes) + + if is_img_or_video: + codes = rearrange(codes, "b ... d -> b d ...") + + return codes.to(self.dtype) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + """ + einstein notation + b - batch + n - sequence (or flattened spatial dimensions) + d - feature dimension, which is also log2(codebook size) + c - number of codebook dim + """ + is_img_or_video = z.ndim >= 4 + + # standardize image or video into (batch, seq, dimension) + + if is_img_or_video: + z = rearrange(z, "b d ... -> b ... d") + z, ps = pack_one(z, "b * d") + + assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}" + + z = self.project_in(z) + + z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks) + + codes = self.quantize(z) + indices = self.codes_to_indices(codes) + + codes = rearrange(codes, "b n c d -> b n (c d)") + + out = self.project_out(codes) + + # reconstitute image or video dimensions + + if is_img_or_video: + out = unpack_one(out, ps, "b * d") + out = rearrange(out, "b ... d -> b d ...") + indices = unpack_one(indices, ps, "b * c") + dummy_loss = torch.zeros_like(out.mean(dim=[1, 2, 3], keepdim=True)) + else: + dummy_loss = torch.zeros_like(out.mean(dim=[1, 2], keepdim=True)).unsqueeze(1) + + if not self.keep_num_codebooks_dim: + indices = rearrange(indices, "... 1 -> ...") + + return (indices, out.to(self.dtype), dummy_loss) diff --git a/ar_tokenizer_text_tokenizer.py b/ar_tokenizer_text_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..9918ab7cc8f55dc0c159b58c158d3556b6819acd --- /dev/null +++ b/ar_tokenizer_text_tokenizer.py @@ -0,0 +1,317 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import AutoTokenizer + +from .log import log + + +def get_tokenizer_path(model_family: str, is_instruct_model: bool = False): + """ + Get the tokenizer path from the model family and instruct model flag. + Args: + model_family (str): The model family. + is_instruct_model (bool): Whether the model is an instruct model. + Returns: + str: The tokenizer path in s3. + """ + model_family = model_family.lower() + if model_family == "mistral": + return "mistralai/Mistral-Nemo-Instruct-2407" + else: + assert model_family in ["llama3", "llama3.1"] + if model_family == "llama3": + model_path = "meta-llama/Meta-Llama-3-8B" + elif model_family == "llama3.1": + model_path = "meta-llama/Llama-3.1-8B" + else: + raise ValueError(f"Unsupported model family: {model_family}") + suffix = "-Instruct" if is_instruct_model else "" + model_path = f"{model_path}{suffix}" + return model_path + + +class TextTokenizer: + """ + Text tokenizer class built on HuggingFace's Fast Tokenizer (Rust based). + """ + + def __init__( + self, + model_family: str, + is_instruct_model: bool, + local_path: Optional[str] = None, + ): + """ + Initialize the TextTokenizer. + Args: + model_family (str): The model family. + is_instruct_model (bool): Whether the model is an instruct model. + local_path (Optional[str]): The local path to the tokenizer. If not provided, the tokenizer will be downloaded from the remote path. + """ + if local_path is None: + tokenizer_path = get_tokenizer_path(model_family, is_instruct_model) + else: + tokenizer_path = local_path + + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True) + self.stop_tokens = { + self.tokenizer.eos_token_id, + } + self.model_family = model_family + self.is_instruct_model = is_instruct_model + self.eos_id = self.tokenizer.eos_token_id + if self.tokenizer.pad_token is None: + if model_family.startswith("llama"): + self.pad_id = 128004 # "<|finetune_right_pad_id|>" + elif model_family == "mistral": + self.pad_id = 10 # "" + elif model_family == "pixtral": + self.pad_id = 11 # "" + else: + raise ValueError(f"pad_id not defined for model_family {model_family}") + else: + self.pad_id = self.tokenizer.pad_token_id + + def tokenize(self, text: str, *, add_special_tokens: bool = False, **kwargs) -> List[str]: + """ + Converts a string into a sequence of tokens, replacing unknown tokens with the `unk_token`. + + Args: + text (`str`): + The sequence to be encoded. + add_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to add the special tokens associated with the corresponding model. + Returns: + `List[str]`: The list of tokens. + """ + return self.tokenizer.tokenize(text, add_special_tokens=add_special_tokens, **kwargs) + + def encode( + self, + text: Union[str, List[str], List[int]], + *, # Enforce keyword-only arguments + add_special_tokens: bool = True, + padding: Union[bool, str] = False, + truncation: Union[bool, str] = None, + max_length: Optional[int] = None, + stride: int = 0, + return_tensors: Optional[str] = None, + **kwargs, + ) -> List[int]: + """ + Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary. + + Args: + text (`str`, `List[str]` or `List[int]`): + The first sequence to be encoded. 
This can be a string, a list of strings (tokenized string using the + `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + add_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to add special tokens when encoding the sequences. This will use the underlying + `PretrainedTokenizerBase.build_inputs_with_special_tokens` function, which defines which tokens are + automatically added to the input ids. This is usefull if you want to add `bos` or `eos` tokens + automatically. + padding (`bool`, `str`, *optional*, defaults to `False`): + Activates and controls padding. Accepts the following values: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, `str`, *optional*, defaults to `False`): + Activates and controls truncation. Accepts the following values: + + - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will + truncate token by token, removing a token from the longest sequence in the pair if a pair of + sequences (or a batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths + greater than the model maximum admissible input size). + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. + + If left unset or set to `None`, this will use the predefined model maximum length if a maximum length + is required by one of the truncation/padding parameters. If the model has no specific maximum input + length (like XLNet) truncation/padding to a maximum length will be deactivated. + stride (`int`, *optional*, defaults to 0): + If set to a number along with `max_length`, the overflowing tokens returned when + `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence + returned to provide some overlap between truncated and overflowing sequences. The value of this + argument defines the number of overlapping tokens. + is_split_into_words (`bool`, *optional*, defaults to `False`): + Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the + tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace) + which it will tokenize. This is useful for NER or token classification. 
+ pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. Requires `padding` to be activated. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + """ + return self.tokenizer.encode( + text, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + return_tensors=return_tensors, + ) + + def decode( + self, + token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor"], + *, # Enforce keyword-only arguments + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = None, + **kwargs, + ) -> str: + """ + Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special + tokens and clean up tokenization spaces. + + Args: + token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Can be obtained using the `__call__` method. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + clean_up_tokenization_spaces (`bool`, *optional*): + Whether or not to clean up the tokenization spaces. If `None`, will default to + `self.clean_up_tokenization_spaces`. + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. + + Returns: + `str`: The decoded sentence. + """ + return self.tokenizer.decode( + token_ids, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + def apply_chat_template( + self, + conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], + *, + add_generation_prompt: bool = False, + tokenize: bool = True, + padding: bool = False, + truncation: bool = False, + max_length: Optional[int] = None, + return_tensors: Optional[str] = None, + return_dict: bool = False, + return_assistant_tokens_mask: bool = False, + generation_prefix: str = "", + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """ + Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token + ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to determine the format and control tokens to use when converting. + + More details can be found at https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template + + Args: + conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts + with "role" and "content" keys, representing the chat history so far. + add_generation_prompt (bool, *optional*): + If this is set, a prompt with the token(s) that indicate + the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model. + Note that this argument will be passed to the chat template, and so it must be supported in the + template for this argument to have any effect. 
+            continue_final_message (bool, *optional*):
+                If this is set, the chat will be formatted so that the final
+                message in the chat is open-ended, without any EOS tokens. The model will continue this message
+                rather than starting a new one. This allows you to "prefill" part of
+                the model's response for it. Cannot be used at the same time as `add_generation_prompt`.
+            tokenize (`bool`, defaults to `True`):
+                Whether to tokenize the output. If `False`, the output will be a string.
+            padding (`bool`, defaults to `False`):
+                Whether to pad sequences to the maximum length. Has no effect if tokenize is `False`.
+            truncation (`bool`, defaults to `False`):
+                Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`.
+            max_length (`int`, *optional*):
+                Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If
+                not specified, the tokenizer's `max_length` attribute will be used as a default.
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable
+                values are:
+                - `'tf'`: Return TensorFlow `tf.Tensor` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return NumPy `np.ndarray` objects.
+                - `'jax'`: Return JAX `jnp.ndarray` objects.
+            return_dict (`bool`, defaults to `False`):
+                Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
+            generation_prefix (str): Prefix to add before asking the model to generate. Helps guide the generation. Defaults to `""`.
+            tokenizer_kwargs (`Dict[str, Any]`, *optional*): Additional kwargs to pass to the tokenizer.
+            return_assistant_tokens_mask (`bool`, defaults to `False`):
+                Whether to return a mask of the assistant-generated tokens. For tokens generated by the assistant,
+                the mask will contain 1. For user and system tokens, the mask will contain 0.
+                This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
+            **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
+
+        Returns:
+            `Union[List[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This
+            output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is
+            set, will return a dict of tokenizer outputs instead.
+        """
+        if not self.is_instruct_model:
+            raise ValueError(
+                "apply_chat_template is only supported for instruct models. You should pass argument is_instruct_model=True to the TextTokenizer constructor."
+            )
+        # Since generation_prefix is appended to the text at the end, ensure that the settings are consistent
+        if generation_prefix:
+            assert not tokenize, "tokenize must be False when generation_prefix is provided."
+            assert add_generation_prompt, "add_generation_prompt must be set when generation_prefix is provided."
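+        # Illustrative sketch of the expected input format (hypothetical values):
+        #     conversation = [{"role": "user", "content": "Describe the scene."}]
+        # With add_generation_prompt=True and tokenize=False, the underlying Hugging Face tokenizer
+        # returns the fully formatted prompt string; any generation_prefix is then appended below to
+        # steer the beginning of the assistant's reply.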
+ formatted_text: Union[str, List[int]] = self.tokenizer.apply_chat_template( + conversation, + add_generation_prompt=add_generation_prompt, + tokenize=tokenize, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + return_dict=return_dict, + return_assistant_tokens_mask=return_assistant_tokens_mask, + tokenizer_kwargs=tokenizer_kwargs, + **kwargs, + ) + if generation_prefix: + formatted_text: str = formatted_text + generation_prefix + log.debug( + f"Adding generation prefix: {generation_prefix} to the formatted text\n" + f"Formatted text: {formatted_text}" + ) + return formatted_text diff --git a/ar_tokenizer_utils.py b/ar_tokenizer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b9dd58c7830e60e5a09a38b991ccb5fef3b13293 --- /dev/null +++ b/ar_tokenizer_utils.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch +from einops import pack, rearrange, unpack + + +def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]: + batch_size = x.shape[0] + return rearrange(x, "b c t h w -> (b t) c h w"), batch_size + + +def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor: + return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size) + + +def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]: + batch_size, height = x.shape[0], x.shape[-2] + return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height + + +def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor: + return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height) + + +def cast_tuple(t: Any, length: int = 1) -> Any: + return t if isinstance(t, tuple) else ((t,) * length) + + +def replication_pad(x): + return torch.cat([x[:, :, :1, ...], x], dim=2) + + +def divisible_by(num: int, den: int) -> bool: + return (num % den) == 0 + + +def is_odd(n: int) -> bool: + return not divisible_by(n, 2) + + +def nonlinearity(x): + return x * torch.sigmoid(x) + + +class CausalNormalize(torch.nn.Module): + def __init__(self, in_channels, num_groups=1): + super().__init__() + self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.num_groups = num_groups + + def forward(self, x): + # if num_groups !=1, we apply a spatio-temporal groupnorm for backward compatibility purpose. + # All new models should use num_groups=1, otherwise causality is not guaranteed. 
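+        # When num_groups == 1, the time axis is folded into the batch axis (time2batch) so that
+        # GroupNorm statistics are computed independently for every frame, and the original
+        # "b c t h w" layout is then restored with batch2time. Since no statistics are shared
+        # across frames, normalization cannot leak information backward in time.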
+ if self.num_groups == 1: + x, batch_size = time2batch(x) + return batch2time(self.norm(x), batch_size) + return self.norm(x) + + +def exists(v): + return v is not None + + +def default(*args): + for arg in args: + if exists(arg): + return arg + return None + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +def round_ste(z: torch.Tensor) -> torch.Tensor: + """Round with straight through gradients.""" + zhat = z.round() + return z + (zhat - z).detach() + + +def log(t, eps=1e-5): + return t.clamp(min=eps).log() diff --git a/ar_transformer.py b/ar_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d05a0444568cf25b10c7422cfec1268cab944f81 --- /dev/null +++ b/ar_transformer.py @@ -0,0 +1,461 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn +from torch.nn.modules.module import _IncompatibleKeys + +from .ar_modules_attention import Attention +from .ar_modules_embedding import ( + RotaryPositionEmbeddingPytorchV1, + RotaryPositionEmbeddingPytorchV2, + SinCosPosEmbAxisTE, +) +from .ar_modules_mlp import MLP +from .ar_modules_normalization import create_norm +from .checkpoint import process_state_dict, substrings_to_ignore +from .ar_utils_misc import maybe_convert_to_namespace +from .log import log + + +class TransformerBlock(nn.Module): + """ + A single transformer block consisting of an attention layer and a feed-forward layer. + """ + + def __init__(self, layer_id: int, args=None): + """ + Initializes the TransformerBlock module. + + Args: + layer_id: The ID of the transformer block. + args: The model arguments containing hyperparameters. 
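+                Expected keys (read in the constructor below) include "n_heads", "n_kv_heads", "dim",
+                "head_dim", "max_batch_size", "max_seq_len", "use_qk_normalization", "causal_mask",
+                "ffn_hidden_size", "norm_type", "norm_eps", "insert_cross_attn",
+                "insert_cross_attn_every_k_layers", and "context_dim" (for optional cross-attention).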
+ """ + super().__init__() + args = maybe_convert_to_namespace(args) + attention_args = { + "n_heads": args["n_heads"], + "n_kv_heads": args["n_kv_heads"], + "dim": args["dim"], + "context_dim": None, + "max_batch_size": args["max_batch_size"], + "max_seq_len": args["max_seq_len"], + "use_qk_normalization": args["use_qk_normalization"], + "causal_mask": args["causal_mask"], + "head_dim": args["head_dim"], + "fuse_qkv": getattr(args, "fuse_qkv", False), + "precision": getattr(args, "precision", "bfloat16"), + "attn_type": getattr(args, "attn_type", "self"), + } + self.attention = Attention(**attention_args) + + self.has_cross_attention = False + self.cross_attention, self.cross_attention_norm = None, None + + if args["insert_cross_attn"] and layer_id % args["insert_cross_attn_every_k_layers"] == 0: + self.has_cross_attention = True + cross_attention_args = attention_args.copy() + cross_attention_args.update({"context_dim": args["context_dim"], "fuse_qkv": False, "attn_type": "cross"}) + self.cross_attention = Attention(**cross_attention_args) + self.cross_attention_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"]) + + self.feed_forward = MLP( + dim=args["dim"], + hidden_dim=args["ffn_hidden_size"], + ) + self.layer_id = layer_id + self.attention_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"]) + self.ffn_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"]) + + def forward( + self, + x: torch.Tensor, + rope: RotaryPositionEmbeddingPytorchV2, + input_pos: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Performs the forward pass of the TransformerBlock module. + + Args: + x: The input tensor. + input_pos: The position of the current sequence. Used in inference (with KV cache) only. + freqs_cis: The precomputed frequency values for rotary position embeddings. + mask: The attention mask tensor. + context (Optional[torch.Tensor]): The context tensor added via cross-attn. + context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor. + + Returns: + The output tensor after applying the transformer block. + """ + # Apply attention and residual connection + h = x + self.attention(self.attention_norm(x), rope=rope, input_pos=input_pos, mask=mask) + + # If insert cross-attention, apply CA and residual connection + if self.has_cross_attention: + h = h + self.cross_attention( + self.cross_attention_norm(h), rope=rope, input_pos=input_pos, mask=context_mask, context=context + ) + + # Apply feed-forward network and residual connection + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + def init_weights(self): + """ + Initializes the weights of the transformer block. + """ + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + if self.has_cross_attention: + self.cross_attention_norm.reset_parameters() + self.cross_attention.init_weights(self.weight_init_std) + # zero-init the final output layer of cross-attention + # nn.init.zeros_(self.cross_attention.wo.weight) + + +class Transformer(nn.Module): + """ + The Transformer network consisting of transformer blocks. + """ + + def __init__(self, params, tokenizer_config=None, init_weights: bool = True): + """ + Initializes the Transformer module. 
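+
+        The module stacks `n_layers` TransformerBlock layers on top of a token embedding table,
+        applies rotary (and optionally absolute) position embeddings, and ends with a final
+        normalization layer and a linear projection over the vocabulary.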
+ + Args: + params: The model parameters containing hyperparameters. + tokenizer_config: The model tokenizer configuration. + init_weights (bool): Whether to initialize the weights of the transformer following + TorchTitan's Llama3 initialization scheme. + """ + super().__init__() + # Check if self.params is an OmegaConf DictConfig instance + self.params = maybe_convert_to_namespace(params) + self.vocab_size = params["vocab_size"] + self.n_layers = params["n_layers"] + self.precision = getattr(torch, params["precision"]) + self.tokenizer_config = tokenizer_config + self.num_video_frames = params["num_video_frames"] + + # Token embeddings + self.tok_embeddings = self._create_token_embeddings() + self.rope_config = self._create_rope_config() + + # Transformer layers + self.layers = nn.ModuleList( + [TransformerBlock(layer_id, self.params).to(self.precision) for layer_id in range(self.n_layers)] + ) + + # Final layer normalization + self.norm = create_norm(self.params["norm_type"], dim=self.params["dim"], eps=self.params["norm_eps"]).to( + self.precision + ) + if self.params["pytorch_rope_version"] == "v1": + self.rope = RotaryPositionEmbeddingPytorchV1(**self.rope_config) + elif self.params["pytorch_rope_version"] == "v2": + # Rotary position embeddings + training_type = self.tokenizer_config.training_type if self.tokenizer_config is not None else None + self.rope = RotaryPositionEmbeddingPytorchV2( + seq_len=self.params["max_seq_len"], training_type=training_type, **self.rope_config + ) + else: + raise ValueError(f"Invalid PyTorch RoPE version: {self.params['pytorch_rope_version']}") + # Causal mask + self.causal_mask = torch.tril( + torch.ones(self.params["max_seq_len"], self.params["max_seq_len"], dtype=torch.bool) + ).cuda() + + # Output projection + self.output = self._create_output_projection() + + # Freeze network parameters for finetuning w/ cross-attention + self.has_cross_attention = getattr(params, "insert_cross_attn", False) + + # Absolute position embeddings + if self.params["apply_abs_pos_emb"]: + self.pos_emb_config = self._create_abs_pos_emb_config() + self.pos_emb, self.abs_pos_emb = self._initialize_abs_pos_emb() + + def _create_rope_config(self) -> Dict: + shape_map = { + "3D": self.params["video_latent_shape"], + "1D": None, + } + latent_shape = shape_map.get(self.params["rope_dim"], None) + head_dim = self.params["head_dim"] + if head_dim is None: + head_dim = self.params["dim"] // self.params["n_heads"] + return { + "dim": head_dim, + "max_position_embeddings": self.params["max_seq_len"], + "original_max_position_embeddings": self.params["original_seq_len"], + "rope_theta": self.params["rope_theta"], + "apply_yarn": self.params["apply_yarn"], + "scale": self.params["yarn_scale"], + "beta_fast": self.params["yarn_beta_fast"], + "beta_slow": self.params["yarn_beta_slow"], + "rope_dim": self.params["rope_dim"], + "latent_shape": latent_shape, + "original_latent_shape": self.params["original_latent_shape"], + "pad_to_multiple_of": self.params["pad_to_multiple_of"], + } + + def _create_abs_pos_emb_config(self): + shape_map = { + "3D": self.params["video_latent_shape"], + "1D": None, + } + latent_shape = shape_map.get(self.params["rope_dim"], None) + return { + "dim": self.params["dim"], + "latent_shape": latent_shape, + "pad_to_multiple_of": self.params["pad_to_multiple_of"], + } + + def _create_token_embeddings(self, vocab_size: int = None): + """ + Create token embeddings. + + Returns: + nn.Module: Token embeddings module. 
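+
+        Note:
+            If `vocab_size` is None, `self.params["vocab_size"]` is used.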
+ """ + if vocab_size is None: + vocab_size = self.params["vocab_size"] + return nn.Embedding(vocab_size, self.params["dim"]).to(self.precision) + + def _create_output_projection(self, vocab_size: int = None): + """ + Create the output projection layer. + + Args: + vocab_size (int): Vocabulary size (to override the default vocab size). + Returns: + LinearTE: Output projection layer. + """ + if vocab_size is None: + vocab_size = self.params["vocab_size"] + return nn.Linear(self.params["dim"], vocab_size, bias=False).to(self.precision) + + def _initialize_abs_pos_emb(self): + pos_emb = SinCosPosEmbAxisTE(**self.pos_emb_config) + training_type = self.tokenizer_config.training_type if self.tokenizer_config is not None else None + abs_pos_emb = pos_emb.forward(training_type=training_type) + return pos_emb, abs_pos_emb + + def forward( + self, + tokens: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + token_embeddings: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Performs the forward pass of the Transformer module. + + Args: + tokens (torch.Tensor, optional): The input tensor of token IDs. + input_pos (Optional[torch.Tensor]): The position of the current sequence. Used in inference with KV cache. + token_embeddings (torch.Tensor, optional): Precomputed token embeddings. If provided, tokens should be None. + context (Optional[torch.Tensor]): The context tensor added via cross-attn. + context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor. + Returns: + The output tensor after applying the transformer layers. + """ + # Token embeddings + assert ( + tokens is None or token_embeddings is None + ), "Either tokens or token_embeddings should be provided, not both." + + if token_embeddings is None: + seq_len = tokens.shape[1] + h = self.tok_embeddings(tokens) + else: + seq_len = token_embeddings.shape[1] + h = token_embeddings + + # Create attention mask + mask = self._create_attention_mask(input_pos=input_pos) + + # Prepare layer arguments + layer_kwargs = self._prepare_layer_kwargs( + input_pos=input_pos, + mask=mask, + context=context, + context_mask=context_mask, + ) + + # Apply transformer layers + for layer in self.layers: + if self.params["apply_abs_pos_emb"]: + h = self.apply_abs_pos_emb(h, input_pos=input_pos) + h = layer(h, **layer_kwargs) + + # Apply final layer normalization + h = self.norm(h) + + # Output linear projection + output = self.output(h) + return output + + def _create_attention_mask(self, input_pos: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + """ + Creates an attention mask for the transformer layers. + + Args: + input_pos[torch.Tensor]: The position of input sequence (used for inference only). + + Returns: + Optional[torch.Tensor]: The attention mask, or None for causal mask. + """ + + assert input_pos is not None, "input_pos must be provided for inference" + mask = self.causal_mask[input_pos] + return mask + + def _prepare_layer_kwargs( + self, + input_pos: Optional[torch.Tensor], + mask: Optional[torch.Tensor], + context: Optional[torch.Tensor], + context_mask: Optional[torch.Tensor], + ) -> Dict[str, Any]: + """ + Prepares the keyword arguments for transformer layers. + + Args: + input_pos (Optional[torch.Tensor]): The position of the current sequence. + mask (Optional[torch.Tensor]): The attention mask. + context (Optional[torch.Tensor]): The context tensor added via cross-attn. 
+ context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor. + + Returns: + Dict[str, Any]: A dictionary of keyword arguments for the transformer layers. + """ + if context is not None: + context = context.to(self.precision) + + if isinstance(mask, torch.Tensor) and mask.ndim == 2: + mask = mask[None, None, :, :] + if isinstance(context_mask, torch.Tensor) and context_mask.ndim == 2: + context_mask = context_mask[None, None, :, :] + + layer_kwargs = { + "mask": mask, + "context": context, + "context_mask": context_mask, + } + + layer_kwargs["input_pos"] = input_pos + layer_kwargs["rope"] = self.rope + + return layer_kwargs + + def apply_abs_pos_emb(self, x: torch.Tensor, input_pos: int = None) -> torch.Tensor: + """ + Applies the absolute position embeddings to the input tensor. + """ + abs_pos_emb = self.abs_pos_emb + abs_pos_emb = abs_pos_emb[:, input_pos, :] if input_pos is not None else abs_pos_emb + return x + abs_pos_emb + + @torch.no_grad() + def expand_vocab( + self, new_vocab_size: int, init_method: str = "gaussian", multiple_of=64, expand_output_layer=True + ): + """ + Expands the vocabulary of the model to the new size. + + Args: + new_vocab_size (int): The new vocabulary size. + init_method (str): The initialization method for new embeddings. + Can be "zero" or "gaussian". Default is "gaussian". + multiple_of (int): The new vocabulary size must be a multiple of this value. Defaults to 64 to fully + leverage the power of NVIDIA TensorCore (source 1: https://x.com/karpathy/status/1621578354024677377, + source 2: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc) + expand_output_layer (bool): Whether to also expand the output layer. Defaults to True. + + Returns: + None + """ + if new_vocab_size <= self.vocab_size: + raise ValueError( + f"New vocabulary size ({new_vocab_size}) must be " f"larger than current size ({self.vocab_size})" + ) + if new_vocab_size % multiple_of != 0: + log.debug(f"New vocabulary size must be a multiple of {multiple_of}. Obtained {new_vocab_size}.") + new_vocab_size = (new_vocab_size // multiple_of + 1) * multiple_of + log.debug(f"Rounded vocabulary size to {new_vocab_size}.") + # Resize token embeddings + old_embeddings = self.tok_embeddings + tensor_kwargs = {"device": old_embeddings.weight.device, "dtype": old_embeddings.weight.dtype} + self.tok_embeddings = self._create_token_embeddings(vocab_size=new_vocab_size).to(**tensor_kwargs) + # Initialize new embeddings + if init_method not in ["zero", "gaussian"]: + raise ValueError(f"Unknown initialization method: {init_method}") + # The default initialization of nn.Embedding is Gaussian, so we don't need to do anything + # if init_method == "gaussian". Only if init_method == "zero", we need to zero out the new embeddings. 
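+        # Note that only the newly added rows [self.vocab_size:new_vocab_size) keep this initialization;
+        # the original rows [0:self.vocab_size) are overwritten with the old embedding weights further below.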
+ if init_method == "zero": + self.tok_embeddings.weight.data[self.vocab_size :].zero_() + + # Copy old embeddings + log.debug( + f"old_embeddings: {old_embeddings.weight.data.shape}, new_embeddings: {self.tok_embeddings.weight.data.shape}, vocab_size: {self.vocab_size}" + ) + self.tok_embeddings.weight.data[: self.vocab_size] = old_embeddings.weight.data + # Resize output layer + old_output = self.output + self.output = self._create_output_projection(vocab_size=new_vocab_size if expand_output_layer else None) + + # Initialize new output weights + self.output.weight.data[self.vocab_size :].zero_() + # Copy old output weights + self.output.weight.data[: self.vocab_size] = old_output.weight.data + + # Update vocab size + self.vocab_size = new_vocab_size + log.debug(f"Expanded vocabulary size to {new_vocab_size}") + + def state_dict(self, *args, **kwargs): + """ + Process the state dict (e.g., remove "_extra_state" keys imposed by TransformerEngine for FP8). + """ + state_dict = super().state_dict(*args, **kwargs) + return process_state_dict(state_dict) + + def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False): + """ + Ignore the missing keys with substrings matching `substring_to_ignore` (e.g., "_extra_state" keys imposed by + TransformerEngine for FP8). + """ + state_dict = process_state_dict(state_dict) + missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False, assign=assign) + if strict: + actual_missing_keys = [] + for key in missing_keys: + if not any(substring in key for substring in substrings_to_ignore): + actual_missing_keys.append(key) + if len(actual_missing_keys) > 0 or len(unexpected_keys) > 0: + raise ValueError(f"Missing keys: {actual_missing_keys}\n\nUnexpected keys: {unexpected_keys}") + missing_keys = actual_missing_keys + return _IncompatibleKeys(missing_keys, unexpected_keys) diff --git a/ar_utils_misc.py b/ar_utils_misc.py new file mode 100644 index 0000000000000000000000000000000000000000..602ea1cb383d8263be06829a466cfb3ba9f97856 --- /dev/null +++ b/ar_utils_misc.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from omegaconf import DictConfig, OmegaConf + + +class CustomSimpleNamespace: + """ + A simple namespace class that supports both attribute-style and dictionary-style access. + """ + + def __init__(self, d): + self._d = d + + def __getattr__(self, attr): + # Attribute-style access: config.key + try: + return self._d[attr] + except KeyError: + raise AttributeError(f"'CustomSimpleNamespace' object has no attribute '{attr}'") + + def __getitem__(self, key): + # Dictionary-style access: config['key'] + return self._d[key] + + +def maybe_convert_to_namespace(config): + """ + This function cast a OmegaConf's DictConfig or a standard dict to CustomSimpleNamespace, which supports both + attribute-style and dictionary-style access. 
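+
+    Example (illustrative; works with any dict-like config):
+        >>> ns = maybe_convert_to_namespace({"dim": 512, "n_heads": 8})
+        >>> ns.dim, ns["n_heads"]
+        (512, 8)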
+ Note: We need to convert OmegaConf's DictConfig since it is not compatible with torch.compile. + """ + # If input is OmegaConf's DictConfig, convert to a standard dict + if isinstance(config, DictConfig): + config = OmegaConf.to_container(config, resolve=True) + + if isinstance(config, dict): + return CustomSimpleNamespace(config) + else: + return config diff --git a/attention.py b/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..2464bc5e1892a3541ce439c0ea36347f43647224 --- /dev/null +++ b/attention.py @@ -0,0 +1,305 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import numpy as np +import torch +import transformer_engine as te +from einops import rearrange +from torch import nn +from torch.utils.checkpoint import checkpoint +from transformer_engine.pytorch.attention import DotProductAttention, apply_rotary_pos_emb + +# ---------------------- Feed Forward Network ----------------------- + + +class FeedForward(nn.Module): + """ + Transformer FFN with optional gating + + Parameters: + d_model (int): Dimensionality of input features. + d_ff (int): Dimensionality of the hidden layer. + dropout (float, optional): Dropout rate applied after the activation function. Defaults to 0.1. + activation (callable, optional): The activation function applied after the first linear layer. + Defaults to nn.ReLU(). + is_gated (bool, optional): If set to True, incorporates gating mechanism to the feed-forward layer. + Defaults to False. + bias (bool, optional): If set to True, adds a bias to the linear layers. Defaults to True. 
+ + Example: + >>> ff = FeedForward(d_model=512, d_ff=2048) + >>> x = torch.randn(64, 10, 512) # Example input tensor + >>> output = ff(x) + >>> print(output.shape) # Expected shape: (64, 10, 512) + """ + + def __init__( + self, + d_model: int, + d_ff: int, + dropout: float = 0.1, + activation=nn.ReLU(), + is_gated: bool = False, + bias: bool = False, + ) -> None: + super().__init__() + + self.layer1 = nn.Linear(d_model, d_ff, bias=bias) + self.layer2 = nn.Linear(d_ff, d_model, bias=bias) + + self.dropout = nn.Dropout(dropout) + self.activation = activation + self.is_gated = is_gated + if is_gated: + self.linear_gate = nn.Linear(d_model, d_ff, bias=False) + + def forward(self, x: torch.Tensor): + g = self.activation(self.layer1(x)) + if self.is_gated: + x = g * self.linear_gate(x) + else: + x = g + assert self.dropout.p == 0.0, "we skip dropout" + return self.layer2(x) + + +class GPT2FeedForward(FeedForward): + def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False): + super().__init__( + d_model=d_model, + d_ff=d_ff, + dropout=dropout, + activation=nn.GELU(), + is_gated=False, + bias=bias, + ) + + def forward(self, x: torch.Tensor): + assert self.dropout.p == 0.0, "we skip dropout" + + x = self.layer1(x) + + def activation_layer2_forward(x): + x = self.activation(x) + x = self.layer2(x) + return x + + x = checkpoint(activation_layer2_forward, x, use_reentrant=False) + return x + + +# ---------------------- Normalization Layer ----------------------- + + +def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor: + """ + Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted. + + Args: + x (torch.Tensor): The input tensor to normalize. + dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first. + eps (float, optional): A small constant to ensure numerical stability during division. + + Returns: + torch.Tensor: The normalized tensor. + """ + if dim is None: + dim = list(range(1, x.ndim)) + norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) + norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel())) + return x / norm.to(x.dtype) + + +def get_normalization(name: str, channels: int): + if name == "I": + return nn.Identity() + elif name == "R": + return te.pytorch.RMSNorm(channels, eps=1e-6) + else: + raise ValueError(f"Normalization {name} not found") + + +class BaseAttentionOp(nn.Module): + def __init__(self): + super().__init__() + + +class Attention(nn.Module): + """ + Generalized attention impl. + + Allowing for both self-attention and cross-attention configurations depending on whether a `context_dim` is provided. + If `context_dim` is None, self-attention is assumed. + + Parameters: + query_dim (int): Dimension of each query vector. + context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed. + heads (int, optional): Number of attention heads. Defaults to 8. + dim_head (int, optional): Dimension of each head. Defaults to 64. + dropout (float, optional): Dropout rate applied to the output of the attention block. Defaults to 0.0. + attn_op (BaseAttentionOp, optional): Custom attention operation to be used instead of the default. + qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False. + out_bias (bool, optional): If True, adds a learnable bias to the output projection. 
Defaults to False. + qkv_norm (str, optional): A string representing normalization strategies for query, key, and value projections. + Defaults to "SSI". + qkv_norm_mode (str, optional): A string representing normalization mode for query, key, and value projections. + Defaults to 'per_head'. Only support 'per_head'. + + Examples: + >>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, dropout=0.1) + >>> query = torch.randn(10, 128) # Batch size of 10 + >>> context = torch.randn(10, 256) # Batch size of 10 + >>> output = attn(query, context) # Perform the attention operation + + Note: + https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + """ + + def __init__( + self, + query_dim: int, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + attn_op: Optional[BaseAttentionOp] = None, + qkv_bias: bool = False, + out_bias: bool = False, + qkv_norm: str = "SSI", + qkv_norm_mode: str = "per_head", + backend: str = "transformer_engine", + qkv_format: str = "bshd", + ) -> None: + super().__init__() + + self.is_selfattn = context_dim is None # self attention + + inner_dim = dim_head * heads + context_dim = query_dim if context_dim is None else context_dim + + self.heads = heads + self.dim_head = dim_head + self.qkv_norm_mode = qkv_norm_mode + self.qkv_format = qkv_format + + if self.qkv_norm_mode == "per_head": + norm_dim = dim_head + else: + raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'") + + self.backend = backend + + self.to_q = nn.Sequential( + nn.Linear(query_dim, inner_dim, bias=qkv_bias), + get_normalization(qkv_norm[0], norm_dim), + ) + self.to_k = nn.Sequential( + nn.Linear(context_dim, inner_dim, bias=qkv_bias), + get_normalization(qkv_norm[1], norm_dim), + ) + self.to_v = nn.Sequential( + nn.Linear(context_dim, inner_dim, bias=qkv_bias), + get_normalization(qkv_norm[2], norm_dim), + ) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim, bias=out_bias), + nn.Dropout(dropout), + ) + + if attn_op: # use what is given + self.attn_op = attn_op + elif self.backend == "transformer_engine": + sequence_parallel = False + self.attn_op: BaseAttentionOp = DotProductAttention( + self.heads, + self.dim_head, + num_gqa_groups=self.heads, + attention_dropout=0, + qkv_format=qkv_format, + attn_mask_type="no_mask", + tp_size=1, + tp_group=None, + sequence_parallel=sequence_parallel, + ) + else: + raise ValueError(f"Backend {backend} not found") + + def cal_qkv( + self, x, context=None, mask=None, rope_emb=None, **kwargs + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + del kwargs + + """ + self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers. + Before 07/24/2024, these modules normalize across all heads. + After 07/24/2024, to support tensor parallelism and follow the common practice in the community, + we support to normalize per head. + To keep the checkpoint copatibility with the previous code, + we keep the nn.Sequential but call the projection and the normalization layers separately. + We use a flag `self.qkv_norm_mode` to control the normalization behavior. + The default value of `self.qkv_norm_mode` is "per_head", which means we normalize per head. + """ + if self.qkv_norm_mode == "per_head": + q = self.to_q[0](x) + context = x if context is None else context + k = self.to_k[0](context) + v = self.to_v[0](context) + q, k, v = map( + lambda t: rearrange(t, "b ... (n c) -> b ... 
n c", n=self.heads, c=self.dim_head), + (q, k, v), + ) + else: + raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'") + + q = self.to_q[1](q) + k = self.to_k[1](k) + v = self.to_v[1](v) + if self.is_selfattn and rope_emb is not None: # only apply to self-attention! + q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=True) + k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=True) + return q, k, v + + def cal_attn(self, q, k, v, mask=None): + if self.backend == "transformer_engine": + seq_dim = self.qkv_format.index("s") + assert ( + q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1 + ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version." + out = self.attn_op(q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None) # [B, Mq, H, V] + return self.to_out(out) + elif self.backend == "torch": + out = self.attn_op(q, k, v, mask=mask) # [B, Mq, H, V] + return self.to_out(rearrange(out, " b ... n c -> b ... (n c)")) + else: + raise ValueError(f"Backend {self.backend} not found") + + def forward( + self, + x, + context=None, + mask=None, + rope_emb=None, + **kwargs, + ): + """ + Args: + x (Tensor): The query tensor of shape [B, Mq, K] + context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None + """ + q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs) + return self.cal_attn(q, k, v, mask) diff --git a/base_world_generation_pipeline.py b/base_world_generation_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..e8f06d31e5df3403347d7903d0633cf191443599 --- /dev/null +++ b/base_world_generation_pipeline.py @@ -0,0 +1,362 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os +from abc import ABC +from typing import Any + +import numpy as np +import torch + +from .t5_text_encoder import CosmosT5TextEncoder +from .presets import presets as guardrail_presets + + +class BaseWorldGenerationPipeline(ABC): + def __init__( + self, + inference_type: str | None = None, + checkpoint_dir: str | None = None, + checkpoint_name: str | None = None, + enable_text_guardrail: bool = False, + enable_video_guardrail: bool = False, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + offload_guardrail_models: bool = False, + ): + """Initialize base world generation pipeline. 
+ + This abstract base class provides core functionality for world generation models including: + - Model loading and initialization + - Text encoding and embedding + - Safety checks and content filtering + - Memory management through model offloading + + Args: + inference_type: The type of inference pipeline ("text2world" or "video2world") + checkpoint_dir: Root directory containing model checkpoints + checkpoint_name: Name of the specific checkpoint file to load + enable_text_guardrail: If True, validates input prompts for safety + enable_video_guardrail: If True, validates generated videos for safety + offload_network: If True, moves main model to CPU after inference + offload_tokenizer: If True, moves tokenizer to CPU after use + offload_text_encoder_model: If True, moves T5 encoder to CPU after encoding + offload_guardrail_models: If True, moves safety models to CPU after checks + """ + self.inference_type = inference_type + self.checkpoint_dir = checkpoint_dir + self.checkpoint_name = checkpoint_name + self.guardrail_dir = "Cosmos-1.0-Guardrail" + self.enable_text_guardrail = enable_text_guardrail + self.enable_video_guardrail = enable_video_guardrail + + # Add offloading flags + self.offload_network = offload_network + self.offload_tokenizer = offload_tokenizer + self.offload_text_encoder_model = offload_text_encoder_model + self.offload_guardrail_models = offload_guardrail_models + + # Initialize model instances + self.text_guardrail = None + self.video_guardrail = None + self.text_encoder = None + self.model = None + + self._load_model() + + if not self.offload_text_encoder_model: + self._load_text_encoder_model() + if not self.offload_guardrail_models: + if self.enable_text_guardrail: + self._load_text_guardrail() + if self.enable_video_guardrail: + self._load_video_guardrail() + if not self.offload_network: + self._load_network() + if not self.offload_tokenizer: + self._load_tokenizer() + + def _load_tokenizer(self): + pass + + def _load_network(self): + pass + + def _load_model(self, checkpoint_name: str) -> Any: + """Load the world generation model from a checkpoint. + + This abstract method must be implemented by subclasses to load their specific + model architecture and weights. + + Args: + checkpoint_name: Path to the model checkpoint file + + Returns: + The loaded model instance + + Raises: + NotImplementedError: Must be implemented by subclasses + """ + pass + + def _load_text_encoder_model(self): + """Load the T5 text encoder model. + + Initializes and loads the T5 encoder model used for converting text prompts + into embeddings that condition the world generation model. + + Returns: + Loaded T5 text encoder model instance + """ + self.text_encoder = CosmosT5TextEncoder(cache_dir=self.checkpoint_dir) + + def _load_text_guardrail(self): + """Load text safety classifier models. + + Initializes models used for checking input prompts against safety policies. + Models are loaded from the specified guardrail directory. + """ + self.text_guardrail = guardrail_presets.create_text_guardrail_runner( + checkpoint_dir=os.path.join(self.checkpoint_dir, self.guardrail_dir) + ) + + def _load_video_guardrail(self): + """Load video safety classifier models. + + Initializes models used for validating generated video content against + safety policies. Models are loaded from the specified guardrail directory. 
+ """ + self.video_guardrail = guardrail_presets.create_video_guardrail_runner( + checkpoint_dir=os.path.join(self.checkpoint_dir, self.guardrail_dir) + ) + + def _offload_network(self): + if self.model.model: + del self.model.model + self.model.model = None + gc.collect() + torch.cuda.empty_cache() + + def _offload_tokenizer(self): + if self.model.tokenizer: + del self.model.tokenizer + self.model.tokenizer = None + gc.collect() + torch.cuda.empty_cache() + + def _offload_guardrail_models(self): + """Offload safety classifier models to reduce memory usage. + + Moves safety models to CPU and clears GPU memory if they are no longer needed. + This helps manage memory when processing multiple inputs sequentially. + """ + if self.text_guardrail: + del self.text_guardrail + self.text_guardrail = None + if self.video_guardrail: + del self.video_guardrail + self.video_guardrail = None + gc.collect() + torch.cuda.empty_cache() + + def _offload_text_encoder_model(self): + """Offload T5 text encoder to reduce memory usage. + + Moves the T5 encoder to CPU and clears GPU memory after text encoding is complete. + This helps manage memory when processing multiple inputs sequentially. + """ + if self.text_encoder: + del self.text_encoder + self.text_encoder = None + gc.collect() + torch.cuda.empty_cache() + + def _run_model(self, *args: Any, **kwargs: Any) -> torch.Tensor: + """Generate world latents using the model. + + This abstract method must be implemented by subclasses to define their specific + generation process. + + Args: + *args: Variable positional arguments for model inference + **kwargs: Variable keyword arguments for model inference + + Returns: + torch.Tensor: Generated world representation tensor + """ + pass + + def _run_model_with_offload(self, *args: Any, **kwargs: Any) -> torch.Tensor: + """Generate world representation with memory management. + + Handles loading the model before inference and offloading afterward if enabled. + This helps minimize GPU memory usage during inference. + + Args: + *args: Arguments passed to _run_model + **kwargs: Keyword arguments passed to _run_model + + Returns: + np.ndarray: Generated world representation as numpy array + """ + pass + + def _run_guardrail_on_prompt(self, prompt: str) -> bool: + """Check if prompt meets safety requirements. + + Validates the input prompt against safety policies using loaded guardrail models. + + Args: + prompt: Raw text prompt to validate + + Returns: + bool: True if prompt passes all safety checks, False otherwise + """ + return guardrail_presets.run_text_guardrail(prompt, self.text_guardrail) + + def _run_guardrail_on_prompt_with_offload(self, prompt: str) -> bool: + """Check prompt safety with memory management. + + Validates prompt safety while handling model loading/offloading to manage memory. + + Args: + prompt: Raw text prompt to validate + + Returns: + bool: True if prompt passes all safety checks, False otherwise + """ + if self.offload_guardrail_models: + self._load_text_guardrail() + + is_safe = self._run_guardrail_on_prompt(prompt) + + if self.offload_guardrail_models: + self._offload_guardrail_models() + + return is_safe + + def _run_guardrail_on_video(self, video: np.ndarray) -> np.ndarray | None: + """Check if video meets safety requirements. + + Validates generated video content against safety policies using guardrail models. 
+ + Args: + video: Video frames to validate + + Returns: + np.ndarray: Processed video if safe, None if unsafe + """ + return guardrail_presets.run_video_guardrail(video, self.video_guardrail) + + def _run_guardrail_on_video_with_offload(self, video: np.ndarray) -> np.ndarray | None: + """Check if generated video meets safety requirements. + + Args: + video: Video frames to validate + + Returns: + np.ndarray: Processed video frames if safe, None otherwise + + Note: + Guardrail models are offloaded after checks if enabled. + """ + if self.offload_guardrail_models: + self._load_video_guardrail() + + video = self._run_guardrail_on_video(video) + + if self.offload_guardrail_models: + self._offload_guardrail_models() + return video + + def _run_text_embedding_on_prompt( + self, prompts: list[str], **kwargs: Any + ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """Convert text prompts to embeddings. + + Processes text prompts into embedding tensors that condition the generation model. + + Args: + prompts: List of text prompts to encode + **kwargs: Additional arguments for text encoding + + Returns: + tuple containing: + - List of text embedding tensors for each prompt + - List of attention masks for each embedding + """ + + embeddings = [] + masks = [] + for prompt in prompts: + embedding, mask = self.text_encoder.encode_prompts( + [prompt], + **kwargs, + ) + embeddings.append(embedding) + masks.append(mask) + + return embeddings, masks + + def _run_text_embedding_on_prompt_with_offload( + self, prompts: list[str], **kwargs: Any + ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """Convert text prompt into embeddings using T5 encoder. + + Args: + prompt: Processed and validated text prompt + + Returns: + Text embedding tensor to condition diffusion model + + Note: + T5 model is offloaded after encoding if enabled. + """ + if self.offload_text_encoder_model: + self._load_text_encoder_model() + + embeddings, masks = self._run_text_embedding_on_prompt(prompts, **kwargs) + + if self.offload_text_encoder_model: + self._offload_text_encoder_model() + return embeddings, masks + + def _run_tokenizer_decoding(self, samples: torch.Tensor) -> np.ndarray: + """Decode model outputs into final world representation. + + This abstract method must be implemented by subclasses to convert raw model + outputs into their specific world representation format. + + Args: + samples: Raw output tensor from the generation model + + Returns: + np.ndarray: Decoded world representation + """ + pass + + def generate(self, *args: Any, **kwargs: Any): + """Generate world representation. + + This abstract method must be implemented by subclasses to convert raw model + outputs into their specific world representation format. + + Args: + *args: Variable positional arguments for model inference + **kwargs: Variable keyword arguments for model inference + """ + pass diff --git a/batch_ops.py b/batch_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..09440b34a95b1708d2154376f2a0202a533cb3b2 --- /dev/null +++ b/batch_ops.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Functions for performing operations with broadcasting to the right axis +# +# Example +# input1: tensor of size (N1, N2) +# input2: tensor of size (N1, N2, N3, N4) +# batch_mul(input1, input2) = input1[:, :, None, None] * input2 +# +# If the common dimensions don't match, we raise an assertion error. + +from torch import Tensor + + +def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: + ndims1 = x.ndim + ndims2 = y.ndim + + common_ndims = min(ndims1, ndims2) + for axis in range(common_ndims): + assert x.shape[axis] == y.shape[axis], "Dimensions not equal at axis {}".format(axis) + + if ndims1 < ndims2: + x = x.reshape(x.shape + (1,) * (ndims2 - ndims1)) + elif ndims2 < ndims1: + y = y.reshape(y.shape + (1,) * (ndims1 - ndims2)) + + return x, y + + +def batch_mul(x: Tensor, y: Tensor) -> Tensor: + x, y = common_broadcast(x, y) + return x * y diff --git a/blocklist.py b/blocklist.py new file mode 100644 index 0000000000000000000000000000000000000000..f7da159e5bf6c76a16d4c9c7488d7905cd6a0e48 --- /dev/null +++ b/blocklist.py @@ -0,0 +1,219 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import os +import re +import string +from difflib import SequenceMatcher + +from .log import log +import nltk +from better_profanity import profanity + +from .guardrail_blocklist_utils import read_keyword_list_from_dir, to_ascii +from .guardrail_core import ContentSafetyGuardrail, GuardrailRunner +from .misc import misc, Color, timer + +DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/blocklist" +CENSOR = Color.red("*") + + +class Blocklist(ContentSafetyGuardrail): + def __init__( + self, + checkpoint_dir: str = DEFAULT_CHECKPOINT_DIR, + guardrail_partial_match_min_chars: int = 4, + guardrail_partial_match_letter_count: float = 0.5, + ) -> None: + nltk.data.path.append(os.path.join(checkpoint_dir, "nltk_data")) + self.lemmatizer = nltk.WordNetLemmatizer() + self.profanity = profanity + self.checkpoint_dir = checkpoint_dir + self.guardrail_partial_match_min_chars = guardrail_partial_match_min_chars + self.guardrail_partial_match_letter_count = guardrail_partial_match_letter_count + + # Load blocklist and whitelist keywords + self.blocklist_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "custom")) + self.whitelist_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "whitelist")) + self.exact_match_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "exact_match")) + + self.profanity.load_censor_words(custom_words=self.blocklist_words, whitelist_words=self.whitelist_words) + log.debug(f"Loaded {len(self.blocklist_words)} words/phrases from blocklist") + log.debug(f"Whitelisted {len(self.whitelist_words)} words/phrases from whitelist") + log.debug(f"Loaded {len(self.exact_match_words)} exact match words/phrases from blocklist") + + def uncensor_whitelist(self, input_prompt: str, censored_prompt: str) -> str: + """Explicitly uncensor words that are in the whitelist.""" + input_words = input_prompt.split() + censored_words = censored_prompt.split() + whitelist_words = set(self.whitelist_words) + for i, token in enumerate(input_words): + if token.strip(string.punctuation).lower() in whitelist_words: + censored_words[i] = token + censored_prompt = " ".join(censored_words) + return censored_prompt + + def censor_prompt(self, input_prompt: str) -> tuple[bool, str]: + """Censor the prompt using the blocklist with better-profanity fuzzy matching. + + Args: + input_prompt: input prompt to censor + + Returns: + bool: True if the prompt is blocked, False otherwise + str: A message indicating why the prompt was blocked + """ + censored_prompt = self.profanity.censor(input_prompt, censor_char=CENSOR) + # Uncensor whitelisted words that were censored from blocklist fuzzy matching + censored_prompt = self.uncensor_whitelist(input_prompt, censored_prompt) + if CENSOR in censored_prompt: + return True, f"Prompt blocked by censorship: Censored Prompt: {censored_prompt}" + return False, "" + + @staticmethod + def check_partial_match( + normalized_prompt: str, normalized_word: str, guardrail_partial_match_letter_count: float + ) -> tuple[bool, str]: + """ + Check robustly if normalized word and the matching target have a difference of up to guardrail_partial_match_letter_count characters. 
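+
+        For example (illustrative numbers): with guardrail_partial_match_letter_count=0.5 and a
+        10-character blocklist entry, any same-word-count window of the prompt is blocked when its
+        SequenceMatcher ratio is at least (10 - 0.5) / 10 = 0.95.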
+ + Args: + normalized_prompt: a string with many words + normalized_word: a string with one or multiple words, its length is smaller than normalized_prompt + guardrail_partial_match_letter_count: maximum allowed difference in characters (float to allow partial characters) + + Returns: + bool: True if a match is found, False otherwise + str: A message indicating why the prompt was blocked + """ + prompt_words = normalized_prompt.split() + word_length = len(normalized_word.split()) + max_similarity_ratio = (len(normalized_word) - float(guardrail_partial_match_letter_count)) / float( + len(normalized_word) + ) + + for i in range(len(prompt_words) - word_length + 1): + # Extract a substring from the prompt with the same number of words as the normalized_word + substring = " ".join(prompt_words[i : i + word_length]) + similarity_ratio = SequenceMatcher(None, substring, normalized_word).ratio() + if similarity_ratio >= max_similarity_ratio: + return ( + True, + f"Prompt blocked by partial match blocklist: Prompt: {normalized_prompt}, Partial Match Word: {normalized_word}", + ) + + return False, "" + + @staticmethod + def check_against_whole_word_blocklist( + prompt: str, + blocklist: list[str], + guardrail_partial_match_min_chars: int = 4, + guardrail_partial_match_letter_count: float = 0.5, + ) -> bool: + """ + Check if the prompt contains any whole words from the blocklist. + The match is case insensitive and robust to multiple spaces between words. + + Args: + prompt: input prompt to check + blocklist: list of words to check against + guardrail_partial_match_min_chars: minimum number of characters in a word to check for partial match + guardrail_partial_match_letter_count: maximum allowed difference in characters for partial match + + Returns: + bool: True if a match is found, False otherwise + str: A message indicating why the prompt was blocked + """ + # Normalize spaces and convert to lowercase + normalized_prompt = re.sub(r"\s+", " ", prompt).strip().lower() + + for word in blocklist: + # Normalize spaces and convert to lowercase for each blocklist word + normalized_word = re.sub(r"\s+", " ", word).strip().lower() + + # Use word boundaries to ensure whole word match + if re.search(r"\b" + re.escape(normalized_word) + r"\b", normalized_prompt): + return True, f"Prompt blocked by exact match blocklist: Prompt: {prompt}, Exact Match Word: {word}" + + # Check for partial match if the word is long enough + if len(normalized_word) >= guardrail_partial_match_min_chars: + match, message = Blocklist.check_partial_match( + normalized_prompt, normalized_word, guardrail_partial_match_letter_count + ) + if match: + return True, message + + return False, "" + + def is_safe(self, input_prompt: str = "") -> tuple[bool, str]: + """Check if the input prompt is safe using the blocklist.""" + # Check if the input is empty + if not input_prompt: + return False, "Input is empty" + input_prompt = to_ascii(input_prompt) + + # Check full sentence for censored words + censored, message = self.censor_prompt(input_prompt) + if censored: + return False, message + + # Check lemmatized words for censored words + tokens = nltk.word_tokenize(input_prompt) + lemmas = [self.lemmatizer.lemmatize(token) for token in tokens] + lemmatized_prompt = " ".join(lemmas) + censored, message = self.censor_prompt(lemmatized_prompt) + if censored: + return False, message + + # Check for exact match blocklist words + censored, message = self.check_against_whole_word_blocklist( + input_prompt, + self.exact_match_words, + 
self.guardrail_partial_match_min_chars, + self.guardrail_partial_match_letter_count, + ) + if censored: + return False, message + + # If all these checks pass, the input is safe + return True, "Input is safe" + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--prompt", type=str, required=True, help="Input prompt") + parser.add_argument( + "--checkpoint_dir", + type=str, + help="Path to the Blocklist checkpoint folder", + default=DEFAULT_CHECKPOINT_DIR, + ) + return parser.parse_args() + + +def main(args): + blocklist = Blocklist(checkpoint_dir=args.checkpoint_dir) + runner = GuardrailRunner(safety_models=[blocklist]) + with timer("blocklist safety check"): + safety, message = runner.run_safety_check(args.prompt) + log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}") + log.info(f"Message: {message}") if not safety else None + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/blocks.py b/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..7a2c1180e5c0845bdfe393db2a0caccac0559a9d --- /dev/null +++ b/blocks.py @@ -0,0 +1,545 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +import numpy as np +import torch +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from torch import nn + +from .attention import Attention, GPT2FeedForward +from .log import log + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class Timesteps(nn.Module): + def __init__(self, num_channels): + super().__init__() + self.num_channels = num_channels + + def forward(self, timesteps): + in_dype = timesteps.dtype + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - 0.0) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + sin_emb = torch.sin(emb) + cos_emb = torch.cos(emb) + emb = torch.cat([cos_emb, sin_emb], dim=-1) + + return emb.to(in_dype) + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False): + super().__init__() + log.debug( + f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility." 
+        )
+        self.linear_1 = nn.Linear(in_features, out_features, bias=not use_adaln_lora)
+        self.activation = nn.SiLU()
+        self.use_adaln_lora = use_adaln_lora
+        if use_adaln_lora:
+            self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False)
+        else:
+            self.linear_2 = nn.Linear(out_features, out_features, bias=True)
+
+    def forward(self, sample: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        emb = self.linear_1(sample)
+        emb = self.activation(emb)
+        emb = self.linear_2(emb)
+
+        if self.use_adaln_lora:
+            adaln_lora_B_3D = emb
+            emb_B_D = sample
+        else:
+            emb_B_D = emb
+            adaln_lora_B_3D = None
+
+        return emb_B_D, adaln_lora_B_3D
+
+
+class FourierFeatures(nn.Module):
+    """
+    Implements a layer that generates Fourier features from input tensors, based on randomly sampled
+    frequencies and phases. This can help in learning high-frequency functions in low-dimensional problems.
+
+    [B] -> [B, D]
+
+    Parameters:
+        num_channels (int): The number of Fourier features to generate.
+        bandwidth (float, optional): The scaling factor for the frequency of the Fourier features. Defaults to 1.
+        normalize (bool, optional): If set to True, the outputs are scaled by sqrt(2), usually to normalize
+            the variance of the features. Defaults to False.
+
+    Example:
+        >>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True)
+        >>> x = torch.randn(10, 256)  # Example input tensor
+        >>> output = layer(x)
+        >>> print(output.shape)  # Expected shape: (10, 256)
+    """
+
+    def __init__(self, num_channels, bandwidth=1, normalize=False):
+        super().__init__()
+        self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
+        self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
+        self.gain = np.sqrt(2) if normalize else 1
+
+    def forward(self, x, gain: float = 1.0):
+        """
+        Apply the Fourier feature transformation to the input tensor.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+            gain (float, optional): An additional gain factor applied during the forward pass. Defaults to 1.
+
+        Returns:
+            torch.Tensor: The transformed tensor, with Fourier features applied.
+        """
+        in_dtype = x.dtype
+        x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
+        x = x.cos().mul(self.gain * gain).to(in_dtype)
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """
+    PatchEmbed is a module for embedding patches from an input tensor by flattening non-overlapping
+    spatio-temporal patches and projecting each patch with a linear layer. This module can process
+    inputs with temporal (video) and spatial (image) dimensions, making it suitable for video and image
+    processing tasks. It supports dividing the input into patches and embedding each patch into a vector
+    of size `out_channels`.
+
+    Parameters:
+    - spatial_patch_size (int): The size of each spatial patch.
+    - temporal_patch_size (int): The size of each temporal patch.
+    - in_channels (int): Number of input channels. Default: 3.
+    - out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
+    - bias (bool): If True, adds a learnable bias to the patch projection. Default: True.
+ """ + + def __init__( + self, + spatial_patch_size, + temporal_patch_size, + in_channels=3, + out_channels=768, + bias=True, + ): + super().__init__() + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + + self.proj = nn.Sequential( + Rearrange( + "b c (t r) (h m) (w n) -> b t h w (c r m n)", + r=temporal_patch_size, + m=spatial_patch_size, + n=spatial_patch_size, + ), + nn.Linear( + in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias + ), + ) + self.out = nn.Identity() + + def forward(self, x): + """ + Forward pass of the PatchEmbed module. + + Parameters: + - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where + B is the batch size, + C is the number of channels, + T is the temporal dimension, + H is the height, and + W is the width of the input. + + Returns: + - torch.Tensor: The embedded patches as a tensor, with shape b t h w c. + """ + assert x.dim() == 5 + _, _, T, H, W = x.shape + assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0 + assert T % self.temporal_patch_size == 0 + x = self.proj(x) + return self.out(x) + + +class FinalLayer(nn.Module): + """ + The final layer of video DiT. + """ + + def __init__( + self, + hidden_size, + spatial_patch_size, + temporal_patch_size, + out_channels, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + ): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear( + hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False + ) + self.hidden_size = hidden_size + self.n_adaln_chunks = 2 + self.use_adaln_lora = use_adaln_lora + if use_adaln_lora: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, adaln_lora_dim, bias=False), + nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False), + ) + else: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False) + ) + + def forward( + self, + x_BT_HW_D, + emb_B_D, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ): + if self.use_adaln_lora: + assert adaln_lora_B_3D is not None + shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk( + 2, dim=1 + ) + else: + shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1) + + B = emb_B_D.shape[0] + T = x_BT_HW_D.shape[0] // B + shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T) + x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D) + + x_BT_HW_D = self.linear(x_BT_HW_D) + return x_BT_HW_D + + +class VideoAttn(nn.Module): + """ + Implements video attention with optional cross-attention capabilities. + + This module processes video features while maintaining their spatio-temporal structure. It can perform + self-attention within the video features or cross-attention with external context features. + + Parameters: + x_dim (int): Dimension of input feature vectors + context_dim (Optional[int]): Dimension of context features for cross-attention. None for self-attention + num_heads (int): Number of attention heads + bias (bool): Whether to include bias in attention projections. Default: False + qkv_norm_mode (str): Normalization mode for query/key/value projections. Must be "per_head". Default: "per_head" + x_format (str): Format of input tensor. Must be "BTHWD". 
Default: "BTHWD" + + Input shape: + - x: (T, H, W, B, D) video features + - context (optional): (M, B, D) context features for cross-attention + where: + T: temporal dimension + H: height + W: width + B: batch size + D: feature dimension + M: context sequence length + """ + + def __init__( + self, + x_dim: int, + context_dim: Optional[int], + num_heads: int, + bias: bool = False, + qkv_norm_mode: str = "per_head", + x_format: str = "BTHWD", + ) -> None: + super().__init__() + self.x_format = x_format + + self.attn = Attention( + x_dim, + context_dim, + num_heads, + x_dim // num_heads, + qkv_bias=bias, + qkv_norm="RRI", + out_bias=bias, + qkv_norm_mode=qkv_norm_mode, + qkv_format="sbhd", + ) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass for video attention. + + Args: + x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D) representing batches of video data. + context (Tensor): Context tensor of shape (B, M, D) or (M, B, D), + where M is the sequence length of the context. + crossattn_mask (Optional[Tensor]): An optional mask for cross-attention mechanisms. + rope_emb_L_1_1_D (Optional[Tensor]): + Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training. + + Returns: + Tensor: The output tensor with applied attention, maintaining the input shape. + """ + + x_T_H_W_B_D = x + context_M_B_D = context + T, H, W, B, D = x_T_H_W_B_D.shape + x_THW_B_D = rearrange(x_T_H_W_B_D, "t h w b d -> (t h w) b d") + x_THW_B_D = self.attn( + x_THW_B_D, + context_M_B_D, + crossattn_mask, + rope_emb=rope_emb_L_1_1_D, + ) + x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W) + return x_T_H_W_B_D + + +def adaln_norm_state(norm_state, x, scale, shift): + normalized = norm_state(x) + return normalized * (1 + scale) + shift + + +class DITBuildingBlock(nn.Module): + """ + A building block for the DiT (Diffusion Transformer) architecture that supports different types of + attention and MLP operations with adaptive layer normalization. + + Parameters: + block_type (str): Type of block - one of: + - "cross_attn"/"ca": Cross-attention + - "full_attn"/"fa": Full self-attention + - "mlp"/"ff": MLP/feedforward block + x_dim (int): Dimension of input features + context_dim (Optional[int]): Dimension of context features for cross-attention + num_heads (int): Number of attention heads + mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0 + bias (bool): Whether to use bias in layers. Default: False + mlp_dropout (float): Dropout rate for MLP. Default: 0.0 + qkv_norm_mode (str): QKV normalization mode. Default: "per_head" + x_format (str): Input tensor format. Default: "BTHWD" + use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False + adaln_lora_dim (int): Dimension for AdaLN-LoRA. 
Default: 256 + """ + + def __init__( + self, + block_type: str, + x_dim: int, + context_dim: Optional[int], + num_heads: int, + mlp_ratio: float = 4.0, + bias: bool = False, + mlp_dropout: float = 0.0, + qkv_norm_mode: str = "per_head", + x_format: str = "BTHWD", + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + ) -> None: + block_type = block_type.lower() + + super().__init__() + self.x_format = x_format + if block_type in ["cross_attn", "ca"]: + self.block = VideoAttn( + x_dim, + context_dim, + num_heads, + bias=bias, + qkv_norm_mode=qkv_norm_mode, + x_format=self.x_format, + ) + elif block_type in ["full_attn", "fa"]: + self.block = VideoAttn( + x_dim, None, num_heads, bias=bias, qkv_norm_mode=qkv_norm_mode, x_format=self.x_format + ) + elif block_type in ["mlp", "ff"]: + self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias) + else: + raise ValueError(f"Unknown block type: {block_type}") + + self.block_type = block_type + self.use_adaln_lora = use_adaln_lora + + self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) + self.n_adaln_chunks = 3 + if use_adaln_lora: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(x_dim, adaln_lora_dim, bias=False), + nn.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False), + ) + else: + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False)) + + def forward( + self, + x: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass for dynamically configured blocks with adaptive normalization. + + Args: + x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D). + emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation. + crossattn_emb (Tensor): Tensor for cross-attention blocks. + crossattn_mask (Optional[Tensor]): Optional mask for cross-attention. + rope_emb_L_1_1_D (Optional[Tensor]): + Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training. + + Returns: + Tensor: The output tensor after processing through the configured block and adaptive normalization. 
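+
+        Note:
+            Each sub-block is applied with adaptive layer-norm modulation, i.e.
+            x = x + gate * block(norm(x) * (1 + scale) + shift),
+            where shift, scale, and gate come from adaLN_modulation(emb_B_D)
+            (plus adaln_lora_B_3D when AdaLN-LoRA is enabled).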
+ """ + if self.use_adaln_lora: + shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( + self.n_adaln_chunks, dim=1 + ) + else: + shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) + + shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = ( + shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + ) + + if self.block_type in ["mlp", "ff"]: + x = x + gate_1_1_1_B_D * self.block( + adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), + ) + elif self.block_type in ["full_attn", "fa"]: + x = x + gate_1_1_1_B_D * self.block( + adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), + context=None, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + elif self.block_type in ["cross_attn", "ca"]: + x = x + gate_1_1_1_B_D * self.block( + adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), + context=crossattn_emb, + crossattn_mask=crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + else: + raise ValueError(f"Unknown block type: {self.block_type}") + + return x + + +class GeneralDITTransformerBlock(nn.Module): + """ + A wrapper module that manages a sequence of DITBuildingBlocks to form a complete transformer layer. + Each block in the sequence is specified by a block configuration string. + + Parameters: + x_dim (int): Dimension of input features + context_dim (int): Dimension of context features for cross-attention blocks + num_heads (int): Number of attention heads + block_config (str): String specifying block sequence (e.g. "ca-fa-mlp" for cross-attention, + full-attention, then MLP) + mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0 + x_format (str): Input tensor format. Default: "BTHWD" + use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False + adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256 + + The block_config string uses "-" to separate block types: + - "ca"/"cross_attn": Cross-attention block + - "fa"/"full_attn": Full self-attention block + - "mlp"/"ff": MLP/feedforward block + + Example: + block_config = "ca-fa-mlp" creates a sequence of: + 1. Cross-attention block + 2. Full self-attention block + 3. 
MLP block + """ + + def __init__( + self, + x_dim: int, + context_dim: int, + num_heads: int, + block_config: str, + mlp_ratio: float = 4.0, + x_format: str = "BTHWD", + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + ): + super().__init__() + self.blocks = nn.ModuleList() + self.x_format = x_format + for block_type in block_config.split("-"): + self.blocks.append( + DITBuildingBlock( + block_type, + x_dim, + context_dim, + num_heads, + mlp_ratio, + x_format=self.x_format, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + ) + ) + + def forward( + self, + x: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if extra_per_block_pos_emb is not None: + x = x + extra_per_block_pos_emb + for block in self.blocks: + x = block( + x, + emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + ) + return x diff --git a/blur_utils.py b/blur_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d52f69d220444a53027b3b4acc3bd192fc6eb76f --- /dev/null +++ b/blur_utils.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import numpy as np + + +def pixelate_face(face_img: np.ndarray, blocks: int = 5) -> np.ndarray: + """ + Pixelate a face region by reducing resolution and then upscaling. + + Args: + face_img: Face region to pixelate + blocks: Number of blocks to divide the face into (in each dimension) + + Returns: + Pixelated face region + """ + h, w = face_img.shape[:2] + # Shrink the image and scale back up to create pixelation effect + temp = cv2.resize(face_img, (blocks, blocks), interpolation=cv2.INTER_LINEAR) + pixelated = cv2.resize(temp, (w, h), interpolation=cv2.INTER_NEAREST) + return pixelated diff --git a/categories.py b/categories.py new file mode 100644 index 0000000000000000000000000000000000000000..f9236da33c8d2669b701cc21709925bacbbe6907 --- /dev/null +++ b/categories.py @@ -0,0 +1,192 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
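+
+# Each category below is a dict with the same schema:
+#   "category":       short identifier for the unsafe-content category
+#   "blocked_reason": message surfaced when a prompt is rejected for this category
+#   "instruction":    policy text ("Should not" / "Can" bullets) intended to be composed into a
+#                     safety classifier prompt
+# All categories are collected in UNSAFE_CATEGORIES at the bottom of this file.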
+ +VIOLENCE = { + "category": "violence", + "blocked_reason": """This prompt is blocked due to the presence of violent content.""", + "instruction": """O1: Violence. +Should not +- Help people plan or engage in violence. +- Encourage people to engage in violence. +- Contain unpleasant vivid violent imagery or violent description that arouses discomfort in the reader. +Can +- Provide information on violence. +- Discuss the topics of violence, and can discuss historical events involving violence. +""", +} + +SEXUAL = { + "category": "sexual", + "blocked_reason": """This prompt is blocked due to the presence of sexual content.""", + "instruction": """O2: Sexual. +Should not +- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities. +- Engage or encourage sexual activity that is considered illegal. +Can +- Discuss topics related to sexuality and sexual education. +- Provide general information regarding sexual activities. +- Provide guidance on topics related to sexual health. +""", +} + +CRIMINAL_PLANNING = { + "category": "criminal_planning", + "blocked_reason": """This prompt is blocked due to the presence of criminal planning/confessions content.""", + "instruction": """O3: Criminal Planning/Confessions. +Should not +- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them. +- Encourage people to engage in criminal activities or condone criminal behavior. +- Have a confession of crime that was committed or will be committed in future. +Can +- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes. +- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems. +""", +} + +WEAPONS = { + "category": "weapons", + "blocked_reason": """This prompt is blocked due to the presence of weapons content.""", + "instruction": """O4: Guns and Illegal Weapons. +Should not +- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them. +- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior. +- Assist or encourage people to execute or own other illegal weapons for example biological agents, or chemical weapons, including instructions on how to create or use them. +Can +- Discuss firearms and the arguments for and against firearm ownership. +""", +} + +SUBSTANCE_ABUSE = { + "category": "substance_abuse", + "blocked_reason": """This prompt is blocked due to the presence of substance abuse content.""", + "instruction": """O5: Controlled/Regulated Substances. +Should not +- Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis. +- Assist or encourage people to create such substances when it is illegal to do so. +Can +- Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are). 
+- Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation. +""", +} + +SUICIDE = { + "category": "suicide", + "blocked_reason": "This prompt is blocked due to the presence of suicide and self harm content.", + "instruction": """O6: Suicide and Self Harm. +Should not +- Encourage people to harm themselves. +- Romanticize or condone self-harm or suicide. +- Provide information regarding the methods of suicide or self-harm. +- Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves. +- Contain confessions or intent of self harm +Can +- Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions). +Should +- Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves. +""", +} + +CSAM = { + "category": "csam", + "blocked_reason": "This prompt contains material related to Child Sexual Abuse", + "instruction": """O7: Sexual (minor). +Should not +- Assist, engage or encourage intention to engage in sexual activity with an underage/minor. +- Discuss or plan sexual acts, sharing of media of sexual nature, personal information or intention to meet in person for sexual activity with a minor. +- Should not include discussion where private information is shared with underage minors for purposes of acting on them and taking actions that are sexual in nature including meeting or sharing media. +- Should not discuss child abuse or engage, encourage, assist in sexual grooming/trafficking of children. +Can +- Provide information about sex education for use in schools. +""", +} + +HATRED = { + "category": "hatred", + "blocked_reason": "This prompt is blocked due to the presence of hate speech content.", + "instruction": """O8: Hate/Identity Hate. +Should not +- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their +race, color, ethnicity, religion, national origin, sexual orientation, gender, gender identity, or disability or caste. +- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics. +- Should not discuss prejudice towards a group classified as protected according to the law of Germany, UK, USA or any other country. +- Should not contain slurs. +- Should not contain text that is distasteful, offensive, or hateful that causes others to feel grief, anger, or fear. +Can +- Provide information on discrimination. +- Discuss the topics of discrimination, and can discuss historical events involving discrimination. +- Discuss neutral statements referring to group identity and contextualized hate speech (that talks about hate but does not incite hate). +""", +} + +HARASSMENT = { + "category": "harassment", + "blocked_reason": "This prompt is blocked due to the presence of harassment content.", + "instruction": """O10: Harassment. +Should not +- Have content that may be used to torment or annoy individuals in real life, or make harassment more likely to occur. 
+- Engage in harassment, misconduct and unlawful discrimination +- Single out an individual or marginalization and/or retaliation based on protected characteristics like Race, Color, Gender, Sex, +Sexual orientation, Gender identity, and gender expression, National origin, Ethnicity, Disability (including being regarded as disabled), +Religion, Age (40+), Pregnancy (including pregnancy, childbirth or related medical conditions), Genetic information, Military or veteran status, Citizenship status, Political activity, or affiliation, +- Taking or requesting statutorily protected leave Body characteristics. +- Promise to give a benefit, or a threat to retaliate or take an adverse action based on a response to harassment and misconduct. +- Bullying. +- Being derogatory and humiliating toward an individual but not necessarily protected characteristics under law. +- Making rude or insulting comments, demeaning, and objectifying terms toward an individual. +Can +- Discuss negative impacts of harassment and bullying based on protected characteristics like race, color, gender, sex, nationality, within a contextualized setting. +- Provide information on protected characteristics or refers, cites material that describes what protected characteristics are. +- Provide information or refer or cite material that informs what harassment is. +""", +} + +THREAT = { + "category": "threat", + "blocked_reason": "This prompt is blocked due to the presence of threat content.", + "instruction": """O11: Threat. +Should not +- Discuss a communication of intent to inflict harm or loss on another person. +- Should not contain discussion of intention of intimidation which is a tactic used between conflicting parties to make the other timid or psychologically insecure for coercion or control. +- Should not contain intention of knowingly putting another person in fear of bodily injury. +- Should not assist or plan injury, physical or mental damage, act or instance of injury, or material or any kind of loss to a person. +- Should not contain threats that are forbidden by law as intent to obtain a monetary advantage or to compel a person to act against their will. +Can +- Provide information about threat is and the meaning of threat. +""", +} + +PROFANITY = { + "category": "profanity", + "blocked_reason": "This prompt is blocked due to the presence of profanity content.", + "instruction": """O12: Profanity. +Should not +- Contain swear words, curse words, or other obscene or profane language +- This includes offensive words used without any intention to act on them or without any explicit targeted recipient +""", +} + +UNSAFE_CATEGORIES = [ + VIOLENCE, + SEXUAL, + CRIMINAL_PLANNING, + WEAPONS, + SUBSTANCE_ABUSE, + SUICIDE, + CSAM, + HATRED, + HARASSMENT, + THREAT, + PROFANITY, +] diff --git a/checkpoint.py b/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..00514ee36fc535c00c979a7802492538d9886fae --- /dev/null +++ b/checkpoint.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional + +import torch + +# Substrings to ignore when processing state dicts +substrings_to_ignore = [ + "_extra_state", # Extra states (BytesIO type) added by TransformerEngine for FP8 handling +] + + +def get_partial_state_dict( + state_dict: Dict[str, torch.Tensor], + prefix: str, +) -> Dict[str, torch.Tensor]: + """ + Get a partial state dict with keys starting with the given prefix + """ + return {k: v for k, v in state_dict.items() if k.startswith(prefix)} + + +def process_state_dict( + state_dict: Dict[str, torch.Tensor], + device: str = None, + dtype: torch.dtype = None, + prefix_to_remove: Optional[str] = None, +) -> Dict[str, torch.Tensor]: + """ + - Remove items with substring "_extra_state" in keys (TransformerEngine adds these for FP8) + - Move tensors to specified device and dtype if provided + + Args: + state_dict (Dict[str, torch.Tensor]): The state dict to process + device (str, optional): The device to move tensors to. Defaults to None. + dtype (torch.dtype, optional): The dtype to move tensors to. Defaults to None. + prefix_to_remove (str, optional): The prefix to remove from the keys of the state dict. Defaults to None. + + Returns: + Dict[str, torch.Tensor]: The processed state dict + """ + new_state_dict = {} + tensor_kwargs = {} + if device is not None: + tensor_kwargs["device"] = device + if dtype is not None: + tensor_kwargs["dtype"] = dtype + + for key, value in state_dict.items(): + # Check if any of the substrings to ignore are in the key + skip = False + for substr in substrings_to_ignore: + if substr in key: + skip = True + break + if skip: + continue + if len(tensor_kwargs) > 0: + value = value.to(**tensor_kwargs) + if prefix_to_remove is not None and key.startswith(prefix_to_remove): + key = key[len(prefix_to_remove) :] + new_state_dict[key] = value + return new_state_dict diff --git a/conditioner.py b/conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..6db3cf9ed06cdcdd30887c47990f04e493d760e8 --- /dev/null +++ b/conditioner.py @@ -0,0 +1,323 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
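+
+# This module defines the conditioning stack for the diffusion model:
+#   - BaseConditionEntry: a single conditioning signal with its own dropout rate and input key
+#   - TextAttr: passes T5 text embeddings/masks through as cross-attention inputs
+#   - BaseVideoCondition / VideoExtendCondition / CosmosCondition: dataclasses bundling the tensors
+#     the network consumes
+#   - GeneralConditioner and its Video(Extend)Conditioner subclasses: aggregate embedders and can
+#     produce conditioned vs. unconditioned (or negative-prompt) outputs, e.g. for classifier-free guidance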
+ +import copy +from abc import ABC, abstractmethod +from collections import defaultdict +from dataclasses import dataclass, fields +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from .batch_ops import batch_mul +from .log import log +from .lazy_config_init import instantiate + + +class BaseConditionEntry(nn.Module): + def __init__(self): + super().__init__() + + self._dropout_rate = None + self._input_key = None + self._return_dict = False + + @property + def dropout_rate(self) -> Union[float, torch.Tensor]: + return self._dropout_rate + + @property + def input_key(self) -> str: + return self._input_key + + @property + def is_return_dict(self) -> bool: + return self._return_dict + + @dropout_rate.setter + def dropout_rate(self, value: Union[float, torch.Tensor]): + self._dropout_rate = value + + @input_key.setter + def input_key(self, value: str): + self._input_key = value + + @is_return_dict.setter + def is_return_dict(self, value: bool): + self._return_dict = value + + @dropout_rate.deleter + def dropout_rate(self): + del self._dropout_rate + + @input_key.deleter + def input_key(self): + del self._input_key + + @is_return_dict.deleter + def is_return_dict(self): + del self._return_dict + + def random_dropout_input( + self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None + ) -> torch.Tensor: + del key + dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate + return batch_mul( + torch.bernoulli((1.0 - dropout_rate) * torch.ones(in_tensor.shape[0])).type_as(in_tensor), + in_tensor, + ) + + def summary(self) -> str: + pass + + +class DataType(Enum): + IMAGE = "image" + VIDEO = "video" + + +class TextAttr(BaseConditionEntry): + def __init__(self): + super().__init__() + + def forward(self, token: torch.Tensor, mask: torch.Tensor): + return {"crossattn_emb": token, "crossattn_mask": mask} + + def random_dropout_input( + self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None + ) -> torch.Tensor: + if key is not None and "mask" in key: + return in_tensor + return super().random_dropout_input(in_tensor, dropout_rate, key) + + +@dataclass +class BaseVideoCondition: + crossattn_emb: torch.Tensor + crossattn_mask: torch.Tensor + data_type: DataType = DataType.VIDEO + padding_mask: Optional[torch.Tensor] = None + fps: Optional[torch.Tensor] = None + num_frames: Optional[torch.Tensor] = None + image_size: Optional[torch.Tensor] = None + scalar_feature: Optional[torch.Tensor] = None + + def to_dict(self) -> Dict[str, Optional[torch.Tensor]]: + return {f.name: getattr(self, f.name) for f in fields(self)} + + +@dataclass +class VideoExtendCondition(BaseVideoCondition): + video_cond_bool: Optional[torch.Tensor] = None # whether or not it conditioned on video + gt_latent: Optional[torch.Tensor] = None + condition_video_indicator: Optional[torch.Tensor] = None # 1 for condition region + + # condition_video_input_mask will concat to the input of network, along channel dim; + # Will be concat with the input tensor + condition_video_input_mask: Optional[torch.Tensor] = None + # condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation, only valid when apply_corruption_to_condition_region is "noise_with_sigma" or "noise_with_sigma_fixed" + condition_video_augment_sigma: Optional[torch.Tensor] = None + + +class GeneralConditioner(nn.Module, ABC): + """ + An abstract module designed to 
handle various embedding models with conditional and + unconditional configurations. This abstract base class initializes and manages a collection + of embedders that can dynamically adjust their dropout rates based on conditioning. + + Attributes: + KEY2DIM (dict): A mapping from output keys to dimensions used for concatenation. + embedders (nn.ModuleDict): A dictionary containing all embedded models initialized and + configured based on the provided configurations. + + Parameters: + emb_models (Union[List, Any]): A dictionary where keys are embedder names and values + are configurations for initializing the embedders. + + """ + + KEY2DIM = {"crossattn_emb": 1, "crossattn_mask": 1} + + def __init__(self, **emb_models: Union[List, Any]): + super().__init__() + self.embedders = nn.ModuleDict() + for n, (emb_name, embconfig) in enumerate(emb_models.items()): + embedder = instantiate(embconfig.obj) + assert isinstance( + embedder, BaseConditionEntry + ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel" + embedder.dropout_rate = getattr(embconfig, "dropout_rate", 0.0) + + if hasattr(embconfig, "input_key"): + embedder.input_key = embconfig.input_key + elif hasattr(embconfig, "input_keys"): + embedder.input_keys = embconfig.input_keys + else: + raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}") + + log.debug(f"Initialized embedder #{n}-{emb_name}: \n {embedder.summary()}") + self.embedders[emb_name] = embedder + + @abstractmethod + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> Any: + """Should be implemented in subclasses to handle conditon datatype""" + raise NotImplementedError + + def _forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> Dict: + """ + Processes the input batch through all configured embedders, applying conditional dropout rates if specified. + Output tensors for each key are concatenated along the dimensions specified in KEY2DIM. + + Parameters: + batch (Dict): The input data batch to process. + override_dropout_rate (Optional[Dict[str, float]]): Optional dictionary to override default dropout rates + per embedder key. + + Returns: + Dict: A dictionary of output tensors concatenated by specified dimensions. + + Note: + In case the network code is sensitive to the order of concatenation, you can either control the order via \ + config file or make sure the embedders return a unique key for each output. 
+ """ + output = defaultdict(list) + if override_dropout_rate is None: + override_dropout_rate = {} + + # make sure emb_name in override_dropout_rate is valid + for emb_name in override_dropout_rate.keys(): + assert emb_name in self.embedders, f"invalid name found {emb_name}" + + for emb_name, embedder in self.embedders.items(): + with torch.no_grad(): + if hasattr(embedder, "input_key") and (embedder.input_key is not None): + emb_out = embedder( + embedder.random_dropout_input( + batch[embedder.input_key], override_dropout_rate.get(emb_name, None) + ) + ) + elif hasattr(embedder, "input_keys"): + emb_out = embedder( + *[ + embedder.random_dropout_input(batch[k], override_dropout_rate.get(emb_name, None), k) + for k in embedder.input_keys + ] + ) + for k, v in emb_out.items(): + output[k].append(v) + # Concatenate the outputs + return {k: torch.cat(v, dim=self.KEY2DIM.get(k, -1)) for k, v in output.items()} + + def get_condition_uncondition( + self, + data_batch: Dict, + ) -> Tuple[Any, Any]: + """ + Processes the provided data batch to generate conditioned and unconditioned outputs. + + This method manipulates dropout rates to simulate two scenarios: + 1. All conditions applied (conditioned) + 2. Conditions removed/reduced to minimum (unconditioned) + + This method sets dropout rates to zero for the conditioned scenario to fully apply + embedders' effects. For unconditioned, it sets rates to 1 (or 0 if initial rate is + insignificant) to minimize embedder influences. + + Parameters: + data_batch (Dict): Input data batch containing all necessary information for + embedding processing. + + Returns: + Tuple[Any, Any]: A tuple containing: + - Outputs with all embedders fully applied (conditioned) + - Outputs with embedders minimized/not applied (unconditioned) + """ + cond_dropout_rates, dropout_rates = {}, {} + for emb_name, embedder in self.embedders.items(): + cond_dropout_rates[emb_name] = 0.0 + dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0 + + condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates) + un_condition: Any = self(data_batch, override_dropout_rate=dropout_rates) + return condition, un_condition + + def get_condition_with_negative_prompt( + self, + data_batch: Dict, + ) -> Tuple[Any, Any]: + """ + Similar functionality as get_condition_uncondition + But use negative prompts for unconditon + """ + cond_dropout_rates, uncond_dropout_rates = {}, {} + for emb_name, embedder in self.embedders.items(): + cond_dropout_rates[emb_name] = 0.0 + if isinstance(embedder, TextAttr): + uncond_dropout_rates[emb_name] = 0.0 + else: + uncond_dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0 + + data_batch_neg_prompt = copy.deepcopy(data_batch) + if "neg_t5_text_embeddings" in data_batch_neg_prompt: + if isinstance(data_batch_neg_prompt["neg_t5_text_embeddings"], torch.Tensor): + data_batch_neg_prompt["t5_text_embeddings"] = data_batch_neg_prompt["neg_t5_text_embeddings"] + data_batch_neg_prompt["t5_text_mask"] = data_batch_neg_prompt["neg_t5_text_mask"] + + condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates) + un_condition: Any = self(data_batch_neg_prompt, override_dropout_rate=uncond_dropout_rates) + + return condition, un_condition + + +@dataclass +class CosmosCondition: + crossattn_emb: torch.Tensor + crossattn_mask: torch.Tensor + padding_mask: Optional[torch.Tensor] = None + scalar_feature: Optional[torch.Tensor] = None + + def to_dict(self) -> Dict[str, Optional[torch.Tensor]]: + return {f.name: 
getattr(self, f.name) for f in fields(self)} + + +class VideoConditioner(GeneralConditioner): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> BaseVideoCondition: + output = super()._forward(batch, override_dropout_rate) + return BaseVideoCondition(**output) + + +class VideoExtendConditioner(GeneralConditioner): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> VideoExtendCondition: + output = super()._forward(batch, override_dropout_rate) + return VideoExtendCondition(**output) diff --git a/config.json b/config.json new file mode 100644 index 0000000000000000000000000000000000000000..44614a6daf1e1e29508f6c68eaeaca1b95921fd0 --- /dev/null +++ b/config.json @@ -0,0 +1,10 @@ +{ + "architectures": [ + "DiffusionText2World" + ], + "auto_map": { + "AutoConfig": "text2world_hf.DiffusionText2WorldConfig", + "AutoModel": "text2world_hf.DiffusionText2World" + }, + "model_type": "AutoModel" +} \ No newline at end of file diff --git a/config.py b/config.py new file mode 100644 index 0000000000000000000000000000000000000000..7f1ecf8c749574a0661b3046ad5cd0db0d540099 --- /dev/null +++ b/config.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, TypeVar + +import attrs + +from omegaconf import DictConfig as LazyDict + +from .misc import Color + +T = TypeVar("T") + + +def _is_attrs_instance(obj: object) -> bool: + """ + Helper function to check if an object is an instance of an attrs-defined class. + + Args: + obj: The object to check. + + Returns: + bool: True if the object is an instance of an attrs-defined class, False otherwise. + """ + return hasattr(obj, "__attrs_attrs__") + + +def make_freezable(cls: T) -> T: + """ + A decorator that adds the capability to freeze instances of an attrs-defined class. + + NOTE: This requires the wrapped attrs to be defined with attrs.define(slots=False) because we need + to hack on a "_is_frozen" attribute. + + This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime. + Once an instance is frozen, its attributes cannot be changed. It also recursively freezes + any attrs-defined objects that are attributes of the class. + + Usage: + @make_freezable + @attrs.define(slots=False) + class MyClass: + attribute1: int + attribute2: str + + obj = MyClass(1, 'a') + obj.freeze() # Freeze the instance + obj.attribute1 = 2 # Raises AttributeError + + Args: + cls: The class to be decorated. + + Returns: + The decorated class with added freezing capability. + """ + + if not hasattr(cls, "__dict__"): + raise TypeError( + "make_freezable cannot be used with classes that do not define __dict__. 
Make sure that the wrapped " + "class was defined with `@attrs.define(slots=False)`" + ) + + original_setattr = cls.__setattr__ + + def setattr_override(self, key, value) -> None: # noqa: ANN001 + """ + Override __setattr__ to allow modifications during initialization + and prevent modifications once the instance is frozen. + """ + if hasattr(self, "_is_frozen") and self._is_frozen and key != "_is_frozen": + raise AttributeError("Cannot modify frozen instance") + original_setattr(self, key, value) # type: ignore + + cls.__setattr__ = setattr_override # type: ignore + + def freeze(self: object) -> None: + """ + Freeze the instance and all its attrs-defined attributes. + """ + for _, value in attrs.asdict(self, recurse=False).items(): + if _is_attrs_instance(value) and hasattr(value, "freeze"): + value.freeze() + self._is_frozen = True # type: ignore + + cls.freeze = freeze # type: ignore + + return cls + + +def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str: + """ + Recursively pretty prints attrs objects with color. + """ + + assert attrs.has(obj.__class__) + + lines: list[str] = [] + for attribute in attrs.fields(obj.__class__): + value = getattr(obj, attribute.name) + if attrs.has(value.__class__): + if use_color: + lines.append(" " * indent + Color.cyan("* ") + Color.green(attribute.name) + ":") + else: + lines.append(" " * indent + "* " + attribute.name + ":") + lines.append(_pretty_print_attrs_instance(value, indent + 1, use_color)) + else: + if use_color: + lines.append( + " " * indent + Color.cyan("* ") + Color.green(attribute.name) + ": " + Color.yellow(value) + ) + else: + lines.append(" " * indent + "* " + attribute.name + ": " + str(value)) + return "\n".join(lines) + + +@make_freezable +@attrs.define(slots=False) +class JobConfig: + # Project name. + project: str = "" + # Experiment name. + group: str = "" + # Run/job name. + name: str = "" + + @property + def path(self) -> str: + return f"{self.project}/{self.group}/{self.name}" + + +@make_freezable +@attrs.define(slots=False) +class Config: + """Config for a job. + + See /README.md/Configuration System for more info. + """ + + # Model configs. + model: LazyDict + + # Training job configs. + job: JobConfig = attrs.field(factory=JobConfig) + + def to_dict(self) -> dict[str, Any]: + return attrs.asdict(self) + + def validate(self) -> None: + """Validate that the config has all required fields.""" + assert self.job.project != "", "Project name is required." + assert self.job.group != "", "Group name is required." + assert self.job.name != "", "Job name is required." diff --git a/config_base_conditioner.py b/config_base_conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..a414a3cc2d26b8b54428e3cbdc9f4ceaf2ae00b6 --- /dev/null +++ b/config_base_conditioner.py @@ -0,0 +1,169 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
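+
+# This module defines attrs-based config entries (TextConfig, FPSConfig, PaddingMaskConfig, ...) that
+# are composed into conditioner configs via LazyCall at the bottom of the file. A new input can usually
+# be wired in with another ReMapkey-based entry; a hypothetical sketch:
+#
+#     @attrs.define(slots=False)
+#     class ScalarFeatureConfig:
+#         obj: LazyDict = L(ReMapkey)(output_key="scalar_feature", dtype="float")
+#         dropout_rate: float = 0.0
+#         input_key: str = "scalar_feature"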
+ +from typing import Dict, List, Optional + +import attrs +import torch + +from .conditioner import BaseConditionEntry, TextAttr, VideoConditioner, VideoExtendConditioner +from .lazy_config_init import LazyCall as L +from .lazy_config_init import LazyDict + + +@attrs.define(slots=False) +class TextConfig: + obj: LazyDict = L(TextAttr)() # No arguments + dropout_rate: float = 0.2 + input_keys: List[str] = attrs.field(factory=lambda: ["t5_text_embeddings", "t5_text_mask"]) + + +class BooleanFlag(BaseConditionEntry): + def __init__(self, output_key: Optional[str] = None): + super().__init__() + self.output_key = output_key + + def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: + del args, kwargs + key = self.output_key if self.output_key else self.input_key + return {key: self.flag} + + def random_dropout_input( + self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None + ) -> torch.Tensor: + del key + dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate + self.flag = torch.bernoulli((1.0 - dropout_rate) * torch.ones(1)).bool().to(device=in_tensor.device) + return in_tensor + + +class ReMapkey(BaseConditionEntry): + def __init__(self, output_key: Optional[str] = None, dtype: Optional[str] = None): + super().__init__() + self.output_key = output_key + self.dtype = { + None: None, + "float": torch.float32, + "bfloat16": torch.bfloat16, + "half": torch.float16, + "float16": torch.float16, + "int": torch.int32, + "long": torch.int64, + }[dtype] + + def forward(self, element: torch.Tensor) -> Dict[str, torch.Tensor]: + key = self.output_key if self.output_key else self.input_key + if isinstance(element, torch.Tensor): + element = element.to(dtype=self.dtype) + return {key: element} + + +@attrs.define(slots=False) +class FPSConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `fps`. + """ + + obj: LazyDict = L(ReMapkey)(output_key="fps", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "fps" + + +@attrs.define(slots=False) +class PaddingMaskConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `padding_mask`. + """ + + obj: LazyDict = L(ReMapkey)(output_key="padding_mask", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "padding_mask" + + +@attrs.define(slots=False) +class ImageSizeConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `image_size`. + """ + + obj: LazyDict = L(ReMapkey)(output_key="image_size", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "image_size" + + +@attrs.define(slots=False) +class NumFramesConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `num_frames`. + """ + + obj: LazyDict = L(ReMapkey)(output_key="num_frames", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "num_frames" + + +@attrs.define(slots=False) +class VideoCondBoolConfig: + obj: LazyDict = L(BooleanFlag)(output_key="video_cond_bool") + dropout_rate: float = 0.2 + input_key: str = "fps" # This is a placeholder, we never use this value + # Config below are for long video generation only + + # Sample PPP... from IPPP... sequence + sample_tokens_start_from_p_or_i: bool = False + + +@attrs.define(slots=False) +class LatentConditionConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `latent condition`. 
+ """ + + obj: LazyDict = L(ReMapkey)(output_key="latent_condition", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "latent_condition" + + +@attrs.define(slots=False) +class LatentConditionSigmaConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `latent condition`. + """ + + obj: LazyDict = L(ReMapkey)(output_key="latent_condition_sigma", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "latent_condition_sigma" + + +BaseVideoConditionerConfig: LazyDict = L(VideoConditioner)( + text=TextConfig(), +) + +VideoConditionerFpsSizePaddingConfig: LazyDict = L(VideoConditioner)( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), +) + +VideoExtendConditionerConfig: LazyDict = L(VideoExtendConditioner)( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), + video_cond_bool=VideoCondBoolConfig(), +) diff --git a/config_helper.py b/config_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..ae20f39d57f08e1c28a2e32f66e804b26303a273 --- /dev/null +++ b/config_helper.py @@ -0,0 +1,198 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import os +import pkgutil +import sys +from dataclasses import fields as dataclass_fields +from dataclasses import is_dataclass +from typing import Any, Dict, Optional + +import attr +import attrs +from hydra import compose, initialize +from hydra.core.config_store import ConfigStore +from omegaconf import DictConfig, OmegaConf + +from .log import log +from .config import Config +from .inference import * + + +def is_attrs_or_dataclass(obj) -> bool: + """ + Check if the object is an instance of an attrs class or a dataclass. + + Args: + obj: The object to check. + + Returns: + bool: True if the object is an instance of an attrs class or a dataclass, False otherwise. + """ + return is_dataclass(obj) or attr.has(type(obj)) + + +def get_fields(obj): + """ + Get the fields of an attrs class or a dataclass. + + Args: + obj: The object to get fields from. Must be an instance of an attrs class or a dataclass. + + Returns: + list: A list of field names. + + Raises: + ValueError: If the object is neither an attrs class nor a dataclass. 
+ """ + if is_dataclass(obj): + return [field.name for field in dataclass_fields(obj)] + elif attr.has(type(obj)): + return [field.name for field in attr.fields(type(obj))] + else: + raise ValueError("The object is neither an attrs class nor a dataclass.") + + +def override(config: Config, overrides: Optional[list[str]] = None) -> Config: + """ + :param config: the instance of class `Config` (usually from `make_config`) + :param overrides: list of overrides for config + :return: the composed instance of class `Config` + """ + # Store the class of the config for reconstruction after overriding. + # config_class = type(config) + + # Convert Config object to a DictConfig object + config_dict = attrs.asdict(config) + config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True}) + # Enforce "--" separator between the script arguments and overriding configs. + if overrides: + if overrides[0] != "--": + raise ValueError('Hydra config overrides must be separated with a "--" token.') + overrides = overrides[1:] + # Use Hydra to handle overrides + cs = ConfigStore.instance() + cs.store(name="config", node=config_omegaconf) + with initialize(version_base=None): + config_omegaconf = compose(config_name="config", overrides=overrides) + OmegaConf.resolve(config_omegaconf) + + def config_from_dict(ref_instance: Any, kwargs: Any) -> Any: + """ + Construct an instance of the same type as ref_instance using the provided dictionary or data or unstructured data + + Args: + ref_instance: The reference instance to determine the type and fields when needed + kwargs: A dictionary of keyword arguments to use for constructing the new instance or primitive data or unstructured data + + Returns: + Any: A new instance of the same type as ref_instance constructed using the provided kwargs or the primitive data or unstructured data + + Raises: + AssertionError: If the fields do not match or if extra keys are found. + Exception: If there is an error constructing the new instance. + """ + is_type = is_attrs_or_dataclass(ref_instance) + if not is_type: + return kwargs + else: + ref_fields = set(get_fields(ref_instance)) + assert isinstance(kwargs, dict) or isinstance( + kwargs, DictConfig + ), "kwargs must be a dictionary or a DictConfig" + keys = set(kwargs.keys()) + + # ref_fields must equal to or include all keys + extra_keys = keys - ref_fields + assert ref_fields == keys or keys.issubset( + ref_fields + ), f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}" + + resolved_kwargs: Dict[str, Any] = {} + for f in keys: + resolved_kwargs[f] = config_from_dict(getattr(ref_instance, f), kwargs[f]) + try: + new_instance = type(ref_instance)(**resolved_kwargs) + except Exception as e: + log.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}") + log.error(e) + raise e + return new_instance + + config = config_from_dict(config, config_omegaconf) + + return config + + +def get_config_module(config_file: str) -> str: + if not config_file.endswith(".py"): + log.error("Config file cannot be specified as module.") + log.error("Please provide the path to the Python config file (relative to the Cosmos root).") + assert os.path.isfile(config_file), f"Cosmos config file ({config_file}) not found." + # Convert to importable module format. 
+ config_module = config_file.replace("/", ".").replace(".py", "") + return config_module + + +def import_all_modules_from_package(package_path: str, reload: bool = False, skip_underscore: bool = True) -> None: + """ + Import all modules from the specified package path recursively. + + This function is typically used in conjunction with Hydra to ensure that all modules + within a specified package are imported, which is necessary for registering configurations. + + Example usage: + ```python + import_all_modules_from_package("cosmos1.models.diffusion.config.inference", reload=True, skip_underscore=False) + ``` + + Args: + package_path (str): The dotted path to the package from which to import all modules. + reload (bool): Flag to determine whether to reload modules if they're already imported. + skip_underscore (bool): If True, skips importing modules that start with an underscore. + """ + return # we do not use this function + log.debug(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}") + package = importlib.import_module(package_path) + package_directory = package.__path__ + + def import_modules_recursively(directory: str, prefix: str) -> None: + """ + Recursively imports or reloads all modules in the given directory. + + Args: + directory (str): The file system path to the current package directory. + prefix (str): The module prefix (e.g., 'cosmos1.models.diffusion.config'). + """ + for _, module_name, is_pkg in pkgutil.iter_modules([directory]): + if skip_underscore and module_name.startswith("_"): + log.debug(f"Skipping module {module_name} as it starts with an underscore") + continue + + full_module_name = f"{prefix}.{module_name}" + log.debug(f"{'Reloading' if reload else 'Importing'} module {full_module_name}") + + if full_module_name in sys.modules and reload: + importlib.reload(sys.modules[full_module_name]) + else: + importlib.import_module(full_module_name) + + if is_pkg: + sub_package_directory = os.path.join(directory, module_name) + import_modules_recursively(sub_package_directory, full_module_name) + + for directory in package_directory: + import_modules_recursively(directory, package_path) diff --git a/convert_pixtral_ckpt.py b/convert_pixtral_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..3b92524fc710764cfa1144c89f3efb99c8bc2633 --- /dev/null +++ b/convert_pixtral_ckpt.py @@ -0,0 +1,209 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert pretrained Pixtral vision model weights to checkpoint and verify the checkpoint loading. 
+ + Usage: + + PYTHONPATH=$(pwd) python cosmos1/scripts/convert_pixtral_ckpt.py +""" + +import argparse +import json +import os +import shutil +from glob import glob + +import torch +from huggingface_hub import snapshot_download +from safetensors.torch import load_file + + +def convert_pixtral_checkpoint(checkpoint_dir: str, checkpoint_name: str, vit_type: str): + """ + Main function to convert Pixtral vision model weights to checkpoint and optionally verify and save the converted checkpoint. + + Args: + checkpoint_dir (str): Path to the checkpoint directory + checkpoint_name (str): Name of the checkpoint + vit_type (str): Type of ViT used in the Pixtral model + + This function performs the following steps: + 0. Download the checkpoint from Hugging Face + 1. Loads the original Pixtral checkpoint + 2. Splits the checkpoint into vision encoder, projector, and LLM weights + 3. Reorganizes the weights to match the expected format + 4. Extracts and verifies the vision encoder configuration + 5. Optionally verifies the converted checkpoint by loading it into a VisionTransformer + 6. Optionally saves the converted checkpoint and configuration + """ + + save_dir = os.path.join(checkpoint_dir, checkpoint_name) + os.makedirs(save_dir, exist_ok=True) + # Save the converted checkpoint + save_path = os.path.join(save_dir, "model.pt") + if os.path.exists(save_path) and os.path.getsize(save_path) > 0: + print(f"Checkpoint {save_path} already exists and is not empty") + return + + pixtral_ckpt_dir = os.path.join(checkpoint_dir, "Pixtral-12B-2409") + os.makedirs(pixtral_ckpt_dir, exist_ok=True) + repo_id = "mistralai/Pixtral-12B-2409" + print(f"Downloading {repo_id} to {pixtral_ckpt_dir}...") + snapshot_download( + repo_id=repo_id, + allow_patterns=["params.json", "consolidated.safetensors"], + local_dir=pixtral_ckpt_dir, + local_dir_use_symlinks=False, + ) + orig_dtype = torch.get_default_dtype() + dtype = torch.bfloat16 + torch.set_default_dtype(dtype) + + # Load checkpoint file + ckpt_files = glob(os.path.join(pixtral_ckpt_dir, "*.safetensors")) + assert len(ckpt_files) == 1, "ckpt_dir should contain only one file" + ckpt_path = ckpt_files[0] + ckpt = load_file(ckpt_path) + + # Split checkpoint into weights of vision encoder, projector, and LLM + vit_key_prefix = "vision_encoder." + vit_ckpt = {} + for key, value in ckpt.items(): + if key.startswith(vit_key_prefix): + vit_ckpt[key.lstrip(vit_key_prefix)] = value + + projector_key_prefix = "vision_language_adapter." + projector_ckpt = {} + substring_replacement_map = { + "w_in.": "projector.0.", + "w_out.": "projector.2.", + } + for key, value in ckpt.items(): + if key.startswith(projector_key_prefix): + key = key.lstrip(projector_key_prefix) + for old, new in substring_replacement_map.items(): + key = key.replace(old, new) + projector_ckpt[key] = value + + llm_ckpt = {} + for key, value in ckpt.items(): + if key.startswith(vit_key_prefix) or key.startswith(projector_key_prefix): + continue + llm_ckpt[key] = value + + vlm_ckpt = {} + for key, value in llm_ckpt.items(): + vlm_ckpt["model." + key] = value + for key, value in projector_ckpt.items(): + vlm_ckpt["mm_projector." + key] = value + for key, value in vit_ckpt.items(): + vlm_ckpt["vision_encoder." 
+ key] = value + + # Load config + config_path = os.path.join(pixtral_ckpt_dir, "params.json") + with open(config_path, "r") as f: + pixtral_config = json.load(f) + + # Extract the vision encoder configuration + vision_encoder_config = { + "dim": pixtral_config["vision_encoder"]["hidden_size"], + "num_channels": pixtral_config["vision_encoder"]["num_channels"], + "image_size": pixtral_config["vision_encoder"]["image_size"], + "patch_size": pixtral_config["vision_encoder"]["patch_size"], + "rope_theta": pixtral_config["vision_encoder"]["rope_theta"], + "ffn_hidden_size": pixtral_config["vision_encoder"]["intermediate_size"], + "n_layers": pixtral_config["vision_encoder"]["num_hidden_layers"], + "n_heads": pixtral_config["vision_encoder"]["num_attention_heads"], + "n_kv_heads": pixtral_config["vision_encoder"]["num_attention_heads"], + "norm_type": "rmsnorm", + "norm_eps": pixtral_config["norm_eps"], + "image_token_id": pixtral_config["vision_encoder"]["image_token_id"], + } + # Configuration for the 400M ViT of Pixtral 12B VLM + vit_config = dict( + dim=1024, + num_channels=3, + image_size=1024, + patch_size=16, + rope_theta=10000, + ffn_hidden_size=4096, + n_layers=24, + n_heads=16, + n_kv_heads=16, + norm_type="rmsnorm", + norm_eps=1e-5, + image_token_id=10, + ) + # Compare the two configurations + for key, value in vit_config.items(): + assert vision_encoder_config[key] == value, f"Mismatch in {key}: {vision_encoder_config[key]} != {value}" + + llm_config_keys = [ + "dim", + "n_layers", + "head_dim", + "hidden_dim", + "n_heads", + "n_kv_heads", + "rope_theta", + "norm_eps", + "vocab_size", + ] + assert set(list(pixtral_config.keys())) == set(llm_config_keys + ["vision_encoder"]), "Config keys mismatch" + replace_map = { + "hidden_dim": "ffn_hidden_size", + } + llm_config = {} + for k, v in pixtral_config.items(): + if k in llm_config_keys: + llm_config[replace_map.get(k, k)] = v + elif k == "vision_encoder": + llm_config["vision_encoder"] = vit_type + else: + raise ValueError(f"Unknown key: {k}") + + ckpt_to_save = {"model": vlm_ckpt, "mm_projector": projector_ckpt, "vision_encoder": vit_ckpt} + torch.save(ckpt_to_save, save_path) + print(f"Model saved to {save_path}") + + # Save config + config_path = os.path.join(save_dir, "config.json") + with open(config_path, "w") as f: + json.dump(llm_config, f) + + torch.set_default_dtype(orig_dtype) # Reset the default dtype + + # Remove the original Pixtral checkpoint + shutil.rmtree(pixtral_ckpt_dir, ignore_errors=True) + print(f"Removed {pixtral_ckpt_dir}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Convert pretrained Pixtral vision model weights to checkpoint and verify accuracy" + ) + parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Path to the checkpoint directory") + parser.add_argument( + "--checkpoint_name", + type=str, + default="Pixtral-12B", + help="Name of the checkpoint", + ) + parser.add_argument("--vit_type", default="pixtral-12b-vit", help="Type of ViT used in the Pixtral model") + args = parser.parse_args() + convert_pixtral_checkpoint( + checkpoint_dir=args.checkpoint_dir, checkpoint_name=args.checkpoint_name, vit_type=args.vit_type + ) diff --git a/cosmos1/models/POST_TRAINING.md b/cosmos1/models/POST_TRAINING.md new file mode 100644 index 0000000000000000000000000000000000000000..66ab54711c5807a260cfc7aef85ca2d8fba63e98 --- /dev/null +++ b/cosmos1/models/POST_TRAINING.md @@ -0,0 +1,23 @@ +# Cosmos Post-training + +In the [Cosmos 
paper](https://research.nvidia.com/publication/2025-01_cosmos-world-foundation-model-platform-physical-ai), we discuss several post-training examples of Cosmos pre-trained World Foundation Models (WFMs) for various Physical AI tasks, including + +- General Post-Training: Fine-tune the WFM to generate a target distribution of videos based on a custom dataset. The target distribution could include a specific camera spec or a specific domain such as a factory. +- Instruction Control: Post-trains models for robotic manipulation to predict videos based on textual instructions, enabling robots to visually simulate tasks like folding clothes or picking up objects. +- Action Control: Post-trains models for robotic manipulation to predict the next visual frame based on action vectors, simulating robotic tasks like object handling or movement planning. +- Camera Control: Adds camera pose conditioning to generate 3D-consistent video simulations from single images, enabling joystick-like navigation in virtual environments. +- Multi-View Generation: Post-trains models for autonomous vehicles to generate synchronized multi-view videos from text prompts, simulating driving scenarios with multiple camera perspectives. +- Multi-View Generation with Vehicle Trajectory Control: Extends multi-view generation by incorporating trajectory inputs, enabling precise simulation of driving environments for autonomous vehicles, adhering to specified paths. + +Except for instruction control, where the WFM is post-trained on a dataset of instruction-video pairs, all other cases require minor modifications of the network architectures. Post-training tasks will be supported by the NeMo Framework. In this initial release, we provide post-training scripts for the general post-training of both diffusion and autoregressive WFMs. Scripts for the other post-training tasks will be provided in a future release.
+ +## Post-training Support Matrix + +| Post-training Task | Diffusion WFM | Autoregressive WFM | +|---------------------|---------------|--------------------| +| General post-training | [Supported](../models/diffusion/nemo/post_training/README.md) | [Supported](../models/autoregressive/nemo/post_training/README.md) | +| Instruction control | Coming soon | Coming soon | +| Action control | Coming soon | Coming soon | +| Camera control | Coming soon | Coming soon | +| Multi-view generation | Coming soon | Coming soon | +| Multi-view generation with vehicle trajectory control | Coming soon | Coming soon | diff --git a/cosmos1/models/autoregressive/README.md b/cosmos1/models/autoregressive/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f181e23ba285b5a673b4ca8c092936c32d4bbb74 --- /dev/null +++ b/cosmos1/models/autoregressive/README.md @@ -0,0 +1,427 @@ +# Cosmos Autoregressive-based World Foundation Models + +## Table of Contents +- [Getting Started](#getting-started) + - [Set Up Docker Environment](#set-up-docker-environment) + - [Download Checkpoints](#download-checkpoints) +- [Usage](#usage) + - [Model Types](#model-types) + - [Single and Batch Generation](#single-and-batch-generation) + - [Sample Commands](#sample-commands) + - [Base Models (4B/12B)](#base-basepy-4b-and-12b) + - [Video2World Models (5B/13B)](#video2world-video2worldpy-5b-and-13b) + - [Arguments](#arguments) + - [Common Parameters](#common-parameters) + - [Base Specific Parameters](#base-specific-parameters) + - [Video2World Specific Parameters](#video2world-specific-parameters) + - [Safety Features](#safety-features) + +This page details the steps for using the Cosmos autoregressive-based world foundation models. + +## Getting Started + +### Set Up Docker Environment + +Follow our [Installation Guide](../../../INSTALL.md) to set up the Docker environment. All commands on this page should be run inside Docker. + +### Download Checkpoints + +1. Generate a [Hugging Face](https://huggingface.co/settings/tokens) access token. Set the access token to 'Read' permission (default is 'Fine-grained'). + +2. Log in to Hugging Face with the access token: + +```bash +huggingface-cli login +``` + +3. Download the Cosmos model weights from [Hugging Face](https://huggingface.co/collections/nvidia/cosmos-6751e884dc10e013a0a0d8e6): + +```bash +PYTHONPATH=$(pwd) python cosmos1/scripts/download_autoregressive.py --model_sizes 4B 5B 12B 13B +``` + +4. The downloaded files should be in the following structure: + +``` +checkpoints/ +├── Cosmos-1.0-Autoregressive-4B +│ ├── model.pt +│ └── config.json +├── Cosmos-1.0-Autoregressive-5B-Video2World +│ ├── model.pt +│ └── config.json +├── Cosmos-1.0-Autoregressive-12B +│ ├── model.pt +│ └── config.json +├── Cosmos-1.0-Autoregressive-13B-Video2World +│ ├── model.pt +│ └── config.json +├── Cosmos-1.0-Tokenizer-CV8x8x8 +│ ├── decoder.jit +│ ├── encoder.jit +│ └── mean_std.pt +├── Cosmos-1.0-Tokenizer-DV8x16x16 +│ ├── decoder.jit +│ └── encoder.jit +├── Cosmos-1.0-Diffusion-7B-Decoder-DV8x16x16ToCV8x8x8 +│ ├── aux_vars.pt +│ └── model.pt +└── Cosmos-1.0-Guardrail + ├── aegis/ + ├── blocklist/ + ├── face_blur_filter/ + └── video_content_safety_filter/ +``` + +## Usage + + +### Model Types + +There are two model types available for autoregressive world generation: + +1. 
**Base**: Supports world generation from image/video input + +* Models: `Cosmos-1.0-Autoregressive-4B` and `Cosmos-1.0-Autoregressive-12B` +* Inference script: [base.py](/cosmos1/models/autoregressive/inference/base.py) + +2. **Video2World**: Supports world generation from image/video input and text input + +* Models: `Cosmos-1.0-Autoregressive-5B-Video2World` and `Cosmos-1.0-Autoregressive-13B-Video2World` +* Inference script: [video2world.py](/cosmos1/models/autoregressive/inference/video2world.py) + +Our models now support video extension up to 33 frames. Starting from either a single image or a 9-frame video input, they can generate the remaining frames to reach the 33-frame length (generating 32 or 24 frames, respectively). + +We have evaluated all eight possible configurations (4 models × 2 vision input types: image or video) using 100 test videos on physical AI topics. Below are the failure rates for each configuration: + +| Model | Image input | Video input (9 frames) | +|:------------------------------------------|:--------------:|:-------------------------:| +| Cosmos-1.0-Autoregressive-4B | 15% | 1% | +| Cosmos-1.0-Autoregressive-5B-Video2World | 7% | 2% | +| Cosmos-1.0-Autoregressive-12B | 2% | 1% | +| Cosmos-1.0-Autoregressive-13B-Video2World | 3% | 0% | + +We define failure cases as videos with severe distortions, such as: + +* Sudden appearance of large unexpected objects +* Video degrading to a single solid color + +Note that the following are not considered failures in our analysis: + +* Static video frames +* Minor object distortions or artifacts + +### Single and Batch Generation + +We support both single and batch video generation. + +For generating a single video, `base` mode requires the input argument `--input_image_or_video_path` (image/video input), while `video2world` mode requires both `--input_image_or_video_path` (image/video input) and `--prompt` (text input). + +Note that our model only works with 1024x640 resolution videos. If the input image/video is not in this resolution, it will be resized and cropped. + +For generating a batch of videos, both `base` and `video2world` require `--batch_input_path` (path to a JSONL file). For `base`, the JSONL file should contain one visual input per line in the following format, where each line must contain a "visual_input" field: + +```json +{"visual_input": "path/to/video1.mp4"} +{"visual_input": "path/to/video2.mp4"} +``` + +For `video2world`, each line in the JSONL file must contain both "prompt" and "visual_input" fields: + +```json +{"prompt": "prompt1", "visual_input": "path/to/video1.mp4"} +{"prompt": "prompt2", "visual_input": "path/to/video2.mp4"} +``` + +### Sample Commands + +There are two main demo scripts for autoregressive world generation: `base.py` and `video2world.py`. Below you will find sample commands for single and batch generation, as well as commands for running with low-memory GPUs using model offloading. We also provide a memory usage table comparing different offloading strategies to help with configuration. + +#### Base (base.py): 4B and 12B + +Generates world from image/video input. + +The `input_type` argument can be either `video` or `image`. We have tuned the sampling parameters `top_p` and `temperature` to achieve the best performance. Please use the provided values in the command examples. + +Note that the command examples below all use video input. If you want to use image input, please change the `input_type` to `image`. 
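+
+If you plan to run the batch commands below, the JSONL file passed to `--batch_input_path` can also be written programmatically. The following is a minimal Python sketch; the video paths and the output filename are placeholders, and each line simply follows the `base` format shown above (for `video2world`, add a `"prompt"` field to each line).
+
+```python
+import json
+
+# Placeholder inputs; replace with your own 1024x640 videos.
+videos = ["path/to/video1.mp4", "path/to/video2.mp4"]
+
+# base mode: one JSON object with a "visual_input" field per line
+with open("base.jsonl", "w") as f:
+    for video in videos:
+        f.write(json.dumps({"visual_input": video}) + "\n")
+```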
+ +##### Single Generation + +```bash +# Example using 4B model +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/base.py \ + --input_type=video \ + --input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \ + --video_save_name=Cosmos-1.0-Autoregressive-4B \ + --ar_model_dir=Cosmos-1.0-Autoregressive-4B \ + --top_p=0.8 \ + --temperature=1.0 + +# Example for low-memory GPUs using 4B model with model offloading +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/base.py \ + --input_type=video \ + --input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \ + --video_save_name=Cosmos-1.0-Autoregressive-4B \ + --ar_model_dir=Cosmos-1.0-Autoregressive-4B \ + --top_p=0.8 \ + --temperature=1.0 \ + --offload_guardrail_models \ + --offload_diffusion_decoder \ + --offload_ar_model \ + --offload_tokenizer + +# Example using 12B model +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/base.py \ + --input_type=video \ + --input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \ + --video_save_name=Cosmos-1.0-Autoregressive-12B \ + --ar_model_dir=Cosmos-1.0-Autoregressive-12B \ + --top_p=0.9 \ + --temperature=1.0 + +# Example for low-memory GPUs using 12B model with model offloading +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/base.py \ + --input_type=video \ + --input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \ + --video_save_name=Cosmos-1.0-Autoregressive-12B \ + --ar_model_dir=Cosmos-1.0-Autoregressive-12B \ + --top_p=0.9 \ + --temperature=1.0 \ + --offload_guardrail_models \ + --offload_diffusion_decoder \ + --offload_ar_model \ + --offload_tokenizer +``` + +##### Batch Generation + +```bash +# Example using 4B model +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/base.py \ + --input_type=video \ + --batch_input_path=cosmos1/models/autoregressive/assets/v1p0/batch_inputs/base.jsonl \ + --video_save_folder=outputs/Cosmos-1.0-Autoregressive-4B \ + --ar_model_dir=Cosmos-1.0-Autoregressive-4B \ + --top_p=0.8 \ + --temperature=1.0 + +# Example using 12B model +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/base.py \ + --input_type=video \ + --batch_input_path=cosmos1/models/autoregressive/assets/v1p0/batch_inputs/base.jsonl \ + --video_save_folder=outputs/Cosmos-1.0-Autoregressive-12B \ + --ar_model_dir=Cosmos-1.0-Autoregressive-12B \ + --top_p=0.9 \ + --temperature=1.0 +``` + +##### Example Output + +Here is an example output video generated using base.py with image input, using `Cosmos-1.0-Autoregressive-12B`: + + + +The input image used to generate this video can be found in `cosmos1/models/autoregressive/assets/v1p0/input.jpg`. The image is from [BDD dataset](http://bdd-data.berkeley.edu/). + +Here is an example output video generated using base.py with 9-frame video input, using `Cosmos-1.0-Autoregressive-12B`: + + + +The input video used to generate this video can be found in `cosmos1/models/autoregressive/assets/v1p0/input.mp4`. + +##### Inference Time and GPU Memory Usage + +These numbers may vary based on system specifications and are provided for reference only. 
+ +| Offloading Strategy | Cosmos-1.0-Autoregressive-4B | Cosmos-1.0-Autoregressive-12B | +|-------------|---------|---------| +| No offloading | 31.3 GB | 47.5 GB | +| Guardrails | 28.9 GB | 45.2 GB | +| Guardrails & Diffusion decoder | 28.5 GB | 43.1 GB | +| Guardrails & Diffusion decoder & Tokenizer | 27.3 GB | 42.9 GB | +| Guardrails & Diffusion decoder & Tokenizer & AR model | 18.7 GB | 27.4 GB | + +End-to-end inference runtime on one H100 without offloading and after model initialization: + +| Cosmos-1.0-Autoregressive-4B | Cosmos-1.0-Autoregressive-12B | +|---------|---------| +| ~62 seconds | ~119 seconds | + +#### Video2World (video2world.py): 5B and 13B + +Generates world from image/video and text input. + +The `input_type` argument can be either `text_and_video` or `text_and_image`. We have tuned the sampling parameters `top_p` and `temperature` to achieve the best performance. Please use the provided values in the command examples. + +Note that the command examples below all use video input. If you want to use image input, please change the `input_type` to `text_and_image`. + +##### Single Generation + +```bash +# Example using 5B model +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \ + --input_type=text_and_video \ + --input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \ + --prompt="A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions." \ + --video_save_name=Cosmos-1.0-Autoregressive-5B-Video2World \ + --ar_model_dir=Cosmos-1.0-Autoregressive-5B-Video2World \ + --top_p=0.7 \ + --temperature=1.0 + +# Example for low-memory GPUs using 5B model with model offloading +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \ + --input_type=text_and_video \ + --input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \ + --prompt="A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions." \ + --video_save_name=Cosmos-1.0-Autoregressive-5B-Video2World \ + --ar_model_dir=Cosmos-1.0-Autoregressive-5B-Video2World \ + --top_p=0.7 \ + --temperature=1.0 \ + --offload_guardrail_models \ + --offload_diffusion_decoder \ + --offload_ar_model \ + --offload_tokenizer \ + --offload_text_encoder_model + +# Example using 13B model +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \ + --input_type=text_and_video \ + --input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \ + --prompt="A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions." \ + --video_save_name=Cosmos-1.0-Autoregressive-13B-Video2World \ + --ar_model_dir=Cosmos-1.0-Autoregressive-13B-Video2World \ + --top_p=0.8 \ + --temperature=1.0 \ + --offload_guardrail_models + +# Example for low-memory GPUs using 13B model with model offloading +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \ + --input_type=text_and_video \ + --input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \ + --prompt="A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions." 
\ + --video_save_name=Cosmos-1.0-Autoregressive-13B-Video2World \ + --ar_model_dir=Cosmos-1.0-Autoregressive-13B-Video2World \ + --top_p=0.8 \ + --temperature=1.0 \ + --offload_guardrail_models \ + --offload_diffusion_decoder \ + --offload_ar_model \ + --offload_tokenizer \ + --offload_text_encoder_model +``` + +##### Batch Generation + +```bash +# Example using 5B model +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \ + --input_type=text_and_video \ + --batch_input_path=cosmos1/models/autoregressive/assets/v1p0/batch_inputs/video2world.jsonl \ + --video_save_folder=outputs/Cosmos-1.0-Autoregressive-5B-Video2World \ + --ar_model_dir=Cosmos-1.0-Autoregressive-5B-Video2World \ + --top_p=0.7 \ + --temperature=1.0 + +# Example using 13B model +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \ + --input_type=text_and_video \ + --batch_input_path=cosmos1/models/autoregressive/assets/v1p0/batch_inputs/video2world.jsonl \ + --video_save_folder=outputs/Cosmos-1.0-Autoregressive-13B-Video2World \ + --ar_model_dir=Cosmos-1.0-Autoregressive-13B-Video2World \ + --top_p=0.8 \ + --temperature=1.0 \ + --offload_guardrail_models +``` + +##### Example Output + +Here is an example output video generated using video2world.py with image input, using `Cosmos-1.0-Autoregressive-13B-Video2World`: + + + +The input image used to generate this video can be found in `cosmos1/models/autoregressive/assets/v1p0/input.jpg`. The prompt for generating the video is: + +``` +A driving video captures a serene urban street scene on a sunny day. The camera is mounted on the dashboard of a moving vehicle, providing a first-person perspective as it travels down a two-lane road. The street is lined with parked cars on both sides, predominantly black and silver sedans and SUVs. The road is flanked by a mix of residential and commercial buildings, with a prominent red-brick building on the left side, featuring multiple windows and a flat roof. The sky is clear with a few scattered clouds, casting soft shadows on the street. Trees with lush green foliage line the right side of the road, providing a natural contrast to the urban environment. The camera remains steady, maintaining a consistent forward motion, suggesting a leisurely drive. Traffic is light, with a few vehicles moving in the opposite direction, including a black sedan and a yellow taxi. Street signs are visible, including a no-parking sign on the right. The overall atmosphere is calm and peaceful, with no pedestrians visible, emphasizing the focus on the drive and the surrounding urban landscape. +``` + +Here is an example output video generated using video2world.py with 9-frame video input, using `Cosmos-1.0-Autoregressive-13B-Video2World`: + + + +The input video used to generate this video can be found in `cosmos1/models/autoregressive/assets/v1p0/input.mp4`. The prompt for generating the video is: + +``` +A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions. +``` + +##### Inference Time and GPU Memory Usage + +These numbers may vary based on system specifications and are provided for reference only. 
+ +| Offloading Strategy | Cosmos-1.0-Autoregressive-5B-Video2World | Cosmos-1.0-Autoregressive-13B-Video2World | +|-------------|---------|---------| +| No offloading | 66.2 GB | > 80 GB | +| Guardrails | 58.7 GB | 76.6 GB | +| Guardrails & T5 encoder | 41.3 GB | 58.0 GB | +| Guardrails & T5 encoder & Diffusion decoder | 29.0 GB | 46.9 GB | +| Guardrails & T5 encoder & Diffusion decoder & Tokenizer | 28.8 GB | 46.7 GB | +| Guardrails & T5 encoder & Diffusion decoder & Tokenizer & AR model | 21.1 GB | 30.9 GB | + +End-to-end inference runtime on one H100 with no offloading for 5B model and guardrail offloading for 13B, after model initialization: + +| Cosmos-1.0-Autoregressive-5B-Video2World | Cosmos-1.0-Autoregressive-13B-Video2World | +|---------|---------| +| ~73 seconds | ~150 seconds | + +### Arguments + +#### Common Parameters + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `--checkpoint_dir` | Directory containing model weights | "checkpoints" | +| `--video_save_name` | Output video filename for single video generation | "output" | +| `--video_save_folder` | Folder where all output videos are stored | "outputs/" | +| `--input_image_or_video_path` | Input image or video path. Required for single video generation | None | +| `--batch_input_path` | Folder containing input images or videos. Required for batch video generation | None | +| `--num_input_frames` | Number of input frames to use for Video2World prediction | 9 | +| `--temperature` | Temperature used while sampling | 1.0 (recommend using values in sample commands provided) | +| `--top_p` | Top-p value for top-p sampling | 0.8 (recommend using values in sample commands provided) | +| `--seed` | Random seed | 0 | +| `--disable_diffusion_decoder` | When set to True, use discrete tokenizer to decode discrete tokens to video. Otherwise, use diffusion decoder to decode video | False | +| `--offload_guardrail_models` | Offload guardrail models after inference, used for low-memory GPUs | False | +| `--offload_diffusion_decoder` | Offload diffusion decoder after inference, used for low-memory GPUs | False | +| `--offload_ar_model` | Offload AR model after inference, used for low-memory GPUs | False | +| `--offload_prompt_upsampler` | Offload prompt upsampler after inference, used for low-memory GPUs | False | + +#### Base Specific Parameters + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `--ar_model_dir` | Directory containing AR model weight | "Cosmos-1.0-Autoregressive-4B" | +| `--input_type` | Input type, either `video` or `image` | "video" | + +#### Video2World Specific Parameters + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `--ar_model_dir` | Directory containing AR model weight | "Cosmos-1.0-Autoregressive-4B" | +| `--input_type` | Input type, either `text_and_video` or `text_and_image` | "text_and_video" | +| `--prompt` | Text prompt for single video generation. Required for single video generation | None | +| `--input_prompts_path` | Path to JSONL file for batch video generation. Required for batch video generation | None | +| `--offload_text_encoder_model` | Offload text encoder after inference, used for low-memory GPUs | False | + +### Safety Features + +The model uses a built-in safety guardrail system that cannot be disabled. Generating human faces is not allowed and will be blurred by the guardrail. + +For more information, check out the [Cosmos Guardrail Documentation](../guardrail/README.md). 
diff --git a/cosmos1/models/autoregressive/__init__.py b/cosmos1/models/autoregressive/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos1/models/autoregressive/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos1/models/autoregressive/assets/nemo/finetuned_result.mp4 b/cosmos1/models/autoregressive/assets/nemo/finetuned_result.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7a50b83965dc9bc74a487630134328051cda17bf Binary files /dev/null and b/cosmos1/models/autoregressive/assets/nemo/finetuned_result.mp4 differ diff --git a/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/0.mp4 b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/0.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..210c51cca9c08ce81299418698aee4fe40132b8c Binary files /dev/null and b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/0.mp4 differ diff --git a/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/1.mp4 b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/1.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..818164dab96def4d1da2737dfa79d846be0bec75 Binary files /dev/null and b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/1.mp4 differ diff --git a/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/2.mp4 b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/2.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..5bcb73f77ae32b1743a9baaf192eb0860dcf14ff Binary files /dev/null and b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/2.mp4 differ diff --git a/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/3.mp4 b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/3.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..731783dc79a3de22eec5b9a5f84c4bae88cb9825 Binary files /dev/null and b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/3.mp4 differ diff --git a/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/4.mp4 b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/4.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..5cbfce867cf588e11c0f9f5bc5cb4308a1c1329f Binary files /dev/null and b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/4.mp4 differ diff --git a/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/5.mp4 b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/5.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e9e3fb9c8a1d464ef74938a089573f355a1a04ba Binary files /dev/null and b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/5.mp4 differ diff --git a/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/6.mp4 b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/6.mp4 new file mode 100644 
index 0000000000000000000000000000000000000000..e1620d029b89f4e9da1a626c69100715934c5201 Binary files /dev/null and b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/6.mp4 differ diff --git a/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/7.mp4 b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/7.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..f493a11beb673cc12e43ac6b256a3153c4580662 Binary files /dev/null and b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/7.mp4 differ diff --git a/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/8.mp4 b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/8.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..37b294618103255cf10f8c1ac1bac3d08570c25c Binary files /dev/null and b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/8.mp4 differ diff --git a/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/9.mp4 b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/9.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6317d1d5d99ec4304294a6f2a5b6a697765f6b31 Binary files /dev/null and b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/9.mp4 differ diff --git a/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/base.jsonl b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/base.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..3ef2e68b8475337d26b319593c195a60405e56c7 --- /dev/null +++ b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/base.jsonl @@ -0,0 +1,10 @@ +{"visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/0.mp4"} +{"visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/1.mp4"} +{"visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/2.mp4"} +{"visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/3.mp4"} +{"visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/4.mp4"} +{"visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/5.mp4"} +{"visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/6.mp4"} +{"visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/7.mp4"} +{"visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/8.mp4"} +{"visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/9.mp4"} diff --git a/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/video2world.jsonl b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/video2world.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..6bb3a542a4ca2c58d01983a694b39a45557c1bf0 --- /dev/null +++ b/cosmos1/models/autoregressive/assets/v1p0/batch_inputs/video2world.jsonl @@ -0,0 +1,10 @@ +{"prompt": "A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions.", "visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/0.mp4"} +{"prompt": "A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions.", "visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/1.mp4"} +{"prompt": "A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions.", "visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/2.mp4"} +{"prompt": "A video recorded from a moving vehicle's perspective, 
capturing roads, buildings, landscapes, and changing weather and lighting conditions.", "visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/3.mp4"} +{"prompt": "A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions.", "visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/4.mp4"} +{"prompt": "A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions.", "visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/5.mp4"} +{"prompt": "A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions.", "visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/6.mp4"} +{"prompt": "A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions.", "visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/7.mp4"} +{"prompt": "A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions.", "visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/8.mp4"} +{"prompt": "A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions.", "visual_input": "cosmos1/models/autoregressive/assets/v1p0/batch_inputs/9.mp4"} diff --git a/cosmos1/models/autoregressive/assets/v1p0/input.jpg b/cosmos1/models/autoregressive/assets/v1p0/input.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1c166b2c79117739dfb8353daee59cfac0788689 Binary files /dev/null and b/cosmos1/models/autoregressive/assets/v1p0/input.jpg differ diff --git a/cosmos1/models/autoregressive/assets/v1p0/input.mp4 b/cosmos1/models/autoregressive/assets/v1p0/input.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e9e3fb9c8a1d464ef74938a089573f355a1a04ba Binary files /dev/null and b/cosmos1/models/autoregressive/assets/v1p0/input.mp4 differ diff --git a/cosmos1/models/autoregressive/assets/v1p0/output_from_image_input_12b.mp4 b/cosmos1/models/autoregressive/assets/v1p0/output_from_image_input_12b.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..46f459393348357bb46d835fb5c9cbb9092beb2d Binary files /dev/null and b/cosmos1/models/autoregressive/assets/v1p0/output_from_image_input_12b.mp4 differ diff --git a/cosmos1/models/autoregressive/assets/v1p0/output_from_image_input_13b.mp4 b/cosmos1/models/autoregressive/assets/v1p0/output_from_image_input_13b.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..fb35337add3d59a65b75414709661082efc8c043 Binary files /dev/null and b/cosmos1/models/autoregressive/assets/v1p0/output_from_image_input_13b.mp4 differ diff --git a/cosmos1/models/autoregressive/assets/v1p0/output_from_video_input_12b.mp4 b/cosmos1/models/autoregressive/assets/v1p0/output_from_video_input_12b.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..60b1270b3f42e0be0675d080eeae1b9a5a544a72 Binary files /dev/null and b/cosmos1/models/autoregressive/assets/v1p0/output_from_video_input_12b.mp4 differ diff --git a/cosmos1/models/autoregressive/assets/v1p0/output_from_video_input_13b.mp4 b/cosmos1/models/autoregressive/assets/v1p0/output_from_video_input_13b.mp4 new file mode 100644 index 
0000000000000000000000000000000000000000..9a2c745ec229e35fdfbf75334885908d9131b716 Binary files /dev/null and b/cosmos1/models/autoregressive/assets/v1p0/output_from_video_input_13b.mp4 differ diff --git a/cosmos1/models/autoregressive/configs/__init__.py b/cosmos1/models/autoregressive/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos1/models/autoregressive/configs/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos1/models/autoregressive/configs/base/__init__.py b/cosmos1/models/autoregressive/configs/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos1/models/autoregressive/configs/base/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos1/models/autoregressive/diffusion_decoder/__init__.py b/cosmos1/models/autoregressive/diffusion_decoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos1/models/autoregressive/diffusion_decoder/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/cosmos1/models/autoregressive/diffusion_decoder/config/base/conditioner.py b/cosmos1/models/autoregressive/diffusion_decoder/config/base/conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..93503037bf8f3decde47e8d0250574f68a3b2e91 --- /dev/null +++ b/cosmos1/models/autoregressive/diffusion_decoder/config/base/conditioner.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, Optional + +import torch + +from conditioner import BaseVideoCondition, GeneralConditioner +from config_base_conditioner import ( + FPSConfig, + ImageSizeConfig, + LatentConditionConfig, + LatentConditionSigmaConfig, + NumFramesConfig, + PaddingMaskConfig, + TextConfig, +) +from lazy_config_init import LazyCall as L +from lazy_config_init import LazyDict + + +@dataclass +class VideoLatentDiffusionDecoderCondition(BaseVideoCondition): + # latent_condition will concat to the input of network, along channel dim; + # cfg will make latent_condition all zero padding. + latent_condition: Optional[torch.Tensor] = None + latent_condition_sigma: Optional[torch.Tensor] = None + + +class VideoDiffusionDecoderConditioner(GeneralConditioner): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> VideoLatentDiffusionDecoderCondition: + output = super()._forward(batch, override_dropout_rate) + return VideoLatentDiffusionDecoderCondition(**output) + + +VideoLatentDiffusionDecoderConditionerConfig: LazyDict = L(VideoDiffusionDecoderConditioner)( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), + latent_condition=LatentConditionConfig(), + latent_condition_sigma=LatentConditionSigmaConfig(), +) diff --git a/cosmos1/models/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py b/cosmos1/models/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7e3522c452402335922c68077992ff0e92ad6030 --- /dev/null +++ b/cosmos1/models/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List + +import attrs + +from cosmos1.models.autoregressive.diffusion_decoder.config.registry import register_configs as register_dd_configs +from df_base_model import LatentDiffusionDecoderModelConfig +from df_config_registry import register_configs +from . import config +from config_helper import import_all_modules_from_package + + +@attrs.define(slots=False) +class Config(config.Config): + # default config groups that will be used unless overwritten + # see config groups in registry.py + defaults: List[Any] = attrs.field( + factory=lambda: [ + "_self_", + {"net": None}, + {"conditioner": "basic"}, + {"tokenizer": "tokenizer"}, + {"tokenizer_corruptor": None}, + {"latent_corruptor": None}, + {"pixel_corruptor": None}, + {"experiment": None}, + ] + ) + + +def make_config(): + c = Config(model=LatentDiffusionDecoderModelConfig()) + + # Specifying values through instances of attrs + c.job.project = "cosmos_video4" + c.job.group = "debug" + c.job.name = "delete_${now:%Y-%m-%d}_${now:%H-%M-%S}" + + # Call this function to register config groups for advanced overriding. + register_configs() + register_dd_configs() + + # experiment config are defined in the experiment folder + # call import_all_modules_from_package to register them + import_all_modules_from_package("cosmos1.models.diffusion.config.inference", reload=True) + import_all_modules_from_package("cosmos1.models.autoregressive.diffusion_decoder.config.inference", reload=True) + return c diff --git a/cosmos1/models/autoregressive/diffusion_decoder/config/inference/cosmos_diffusiondecoder_7b.py b/cosmos1/models/autoregressive/diffusion_decoder/config/inference/cosmos_diffusiondecoder_7b.py new file mode 100644 index 0000000000000000000000000000000000000000..004be139e14faa505edc788ae5184e4d952588fd --- /dev/null +++ b/cosmos1/models/autoregressive/diffusion_decoder/config/inference/cosmos_diffusiondecoder_7b.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
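+
+# This file registers the inference-only experiment config for the 7B diffusion decoder in
+# Hydra's ConfigStore, stored under its job name in the "experiment" group. Selecting that
+# experiment applies the net, tokenizer, conditioner, and tokenizer_corruptor overrides
+# defined below.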
+ +from hydra.core.config_store import ConfigStore + +from cosmos1.models.autoregressive.diffusion_decoder.network import DiffusionDecoderGeneralDIT +from lazy_config_init import LazyCall as L +from lazy_config_init import LazyDict + +num_frames = 57 +Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY: LazyDict = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /tokenizer": "cosmos_video_tokenizer_res720_comp8x8x8_t121_ver092624"}, + {"override /conditioner": "video_latent_diffusion_decoder_cond"}, + {"override /tokenizer_corruptor": "cosmos_video_discrete_tokenizer_res720_comp8x16x16_t49_ver110224"}, + "_self_", + ], + job=dict( + group="diffusion_deocder_FT_7Bv1_001", + name="DD_FT_7Bv1_003_002_tokenizer888_spatch2_discrete_cond_on_token", + ), + model=dict( + diffusion_decoder_cond_sigma_low=0.0, + diffusion_decoder_cond_sigma_high=0.0, + diffusion_decoder_corrupt_prob=0.0, + condition_on_tokenizer_corruptor_token=True, + latent_shape=[ + 16, + num_frames, + 88, + 160, + ], + tokenizer_corruptor=dict( + pixel_chunk_duration=num_frames, + latent_chunk_duration=1 + (num_frames - 1) // 8, + ), + net=L(DiffusionDecoderGeneralDIT)( + diffusion_decoder_condition_on_sigma=False, + max_img_h=240, + max_img_w=240, + rope_h_extrapolation_ratio=1.5, + rope_w_extrapolation_ratio=1.5, + rope_t_extrapolation_ratio=1, + block_x_format="THWBD", + is_diffusion_decoder=True, + patch_spatial=2, + diffusion_decoder_condition_on_token=True, + diffusion_decoder_token_condition_voc_size=64000, + diffusion_decoder_token_condition_dim=32, + ), + tokenizer=dict( + video_vae=dict( + pixel_chunk_duration=num_frames, + ) + ), + conditioner=dict( + latent_condition=dict( + dropout_rate=0.2, + ) + ), + ), + ) +) + +cs = ConfigStore.instance() +cs.store( + group="experiment", + package="_global_", + name=Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY["job"]["name"], + node=Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY, +) diff --git a/cosmos1/models/autoregressive/diffusion_decoder/config/registry.py b/cosmos1/models/autoregressive/diffusion_decoder/config/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..b0a2dd89ea79b439cb4b4ef0fcdcffd3829aa7d2 --- /dev/null +++ b/cosmos1/models/autoregressive/diffusion_decoder/config/registry.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
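+
+# Config-group registry for the diffusion decoder: registers the continuous 8x8x8 video
+# tokenizer, the discrete 8x16x16 tokenizer used as the tokenizer corruptor, and the latent
+# diffusion decoder conditioner under the "tokenizer", "tokenizer_corruptor", and
+# "conditioner" groups in Hydra's ConfigStore.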
+ +from hydra.core.config_store import ConfigStore + +from cosmos1.models.autoregressive.diffusion_decoder.config.base.conditioner import ( + VideoLatentDiffusionDecoderConditionerConfig, +) +from discrete_video import DiscreteVideoFSQJITTokenizer +from pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer +from lazy_config_init import LazyCall as L + + +def get_cosmos_video_discrete_tokenizer_comp8x16x16( + resolution: str, + chunk_duration: int, + checkpoint_path: str, +): + assert resolution in ["720"] + + pixel_chunk_duration = chunk_duration + temporal_compression_factor = 8 + spatial_compression_factor = 16 + + return L(DiscreteVideoFSQJITTokenizer)( + enc_fp=checkpoint_path.replace(".jit", "encoder.jit"), + dec_fp=checkpoint_path.replace(".jit", "decoder.jit"), + name="discrete_video_fsq", + latent_ch=6, + is_bf16=True, + pixel_chunk_duration=pixel_chunk_duration, + latent_chunk_duration=1 + (pixel_chunk_duration - 1) // temporal_compression_factor, + max_enc_batch_size=8, + max_dec_batch_size=4, + levels=[8, 8, 8, 5, 5, 5], + compression_ratio=[temporal_compression_factor, spatial_compression_factor, spatial_compression_factor], + ) + + +def get_cosmos_video_tokenizer_comp8x8x8(resolution: str, chunk_duration: int, checkpoint_path=None): + pixel_chunk_duration = chunk_duration + temporal_compression_factor = 8 + spatial_compression_factor = 8 + + return L(JointImageVideoSharedJITTokenizer)( + video_vae=L(VideoJITTokenizer)( + name="cosmos_1_0_diffusion_tokenizer", + latent_ch=16, + is_bf16=True, + pixel_chunk_duration=pixel_chunk_duration, + temporal_compression_factor=temporal_compression_factor, + spatial_compression_factor=spatial_compression_factor, + spatial_resolution=resolution, + ), + image_vae=L(JITVAE)( + name="cosmos_1_0_diffusion_tokenizer", + latent_ch=16, + is_image=False, + is_bf16=True, + ), + name="cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624", + latent_ch=16, + ) + + +def register_tokenizer(cs): + cs.store( + group="tokenizer", + package="model.tokenizer", + name="cosmos_video_tokenizer_res720_comp8x8x8_t121_ver092624", + node=get_cosmos_video_tokenizer_comp8x8x8( + resolution="720", + chunk_duration=121, + checkpoint_path="checkpoints/Cosmos-1.0-Tokenizer-CV8x8x8/.jit", + ), + ) + + +def register_corruptor(cs): + cs.store( + group="tokenizer_corruptor", + package="model.tokenizer_corruptor", + name="cosmos_video_discrete_tokenizer_res720_comp8x16x16_t49_ver110224", + node=get_cosmos_video_discrete_tokenizer_comp8x16x16( + resolution="720", + chunk_duration=49, + checkpoint_path="checkpoints/Cosmos-1.0-Tokenizer-DV8x16x16/.jit", + ), + ) + + +def register_conditioner(cs): + cs.store( + group="conditioner", + package="model.conditioner", + name="video_latent_diffusion_decoder_cond", + node=VideoLatentDiffusionDecoderConditionerConfig, + ) + + +def register_configs(): + cs = ConfigStore.instance() + + register_conditioner(cs) + register_corruptor(cs) + register_tokenizer(cs) diff --git a/cosmos1/models/autoregressive/diffusion_decoder/inference.py b/cosmos1/models/autoregressive/diffusion_decoder/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..4e4116c4dfa0e6ce56ca2ab98edfa7410ebe401d --- /dev/null +++ b/cosmos1/models/autoregressive/diffusion_decoder/inference.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import gc +from typing import List + +import torch + +from inference_config import DiffusionDecoderSamplingConfig +from cosmos1.models.autoregressive.diffusion_decoder.model import LatentDiffusionDecoderModel +from cosmos1.models.autoregressive.diffusion_decoder.utils import linear_blend_video_list, split_with_overlap +from .log import log + + +def diffusion_decoder_process_tokens( + model: LatentDiffusionDecoderModel, + indices_tensor: List[torch.Tensor], + dd_sampling_config: DiffusionDecoderSamplingConfig = None, + original_video_example: torch.Tensor = None, + t5_emb_batch: List[torch.Tensor] = None, +): + _, T, H, W = original_video_example.shape + if dd_sampling_config is None: + dd_sampling_config = DiffusionDecoderSamplingConfig() + # indices_tensor is assumed to be a list of tensors with shape 1LHW + data_batch_list = [] + for sample_num, token_CTHW in enumerate(indices_tensor): + token_BCTHW = token_CTHW.unsqueeze(0).unsqueeze(1) + token_BCTHW = split_with_overlap( + token_BCTHW, + (dd_sampling_config.dd_train_num_video_frames - 1) // 8 + 1, + overlap=dd_sampling_config.overlap, + tobf16=False, + ) + data_batch_list.append( + { + "token_chunks": token_BCTHW, + "t5_text_embeddings": t5_emb_batch[sample_num].to(torch.bfloat16), + "t5_text_mask": torch.ones(1, 512, dtype=torch.bfloat16).cuda(), + # other conditions + "image_size": torch.tensor([[H, W, H, W]] * 1, dtype=torch.bfloat16).cuda(), + "fps": torch.tensor([dd_sampling_config.fps] * 1, dtype=torch.bfloat16).cuda(), + "num_frames": torch.tensor( + [dd_sampling_config.dd_train_num_video_frames] * 1, dtype=torch.bfloat16 + ).cuda(), + "padding_mask": torch.zeros((1, 1, H, W), dtype=torch.bfloat16).cuda(), + } + ) + + out_videos_batch = [] + + for idx, data_batch_template in enumerate(data_batch_list): + full_length_sample = [] + iterations = min(len(data_batch_template["token_chunks"]), dd_sampling_config.max_iter) + for iter in range(iterations): + gc.collect() + torch.cuda.empty_cache() + + data_batch = copy.deepcopy(data_batch_template) + data_batch["video"] = data_batch_template["token_chunks"][iter].cuda().to("cuda") + + log.debug(f"Run iter {iter} for video # {idx} at length {data_batch['video'].shape[2]}") + # org_video, + with torch.no_grad(): + samples_latent = model.generate_samples_from_batch( + data_batch, + guidance=dd_sampling_config.guidance, + sigma_min=dd_sampling_config.sigma_min, + state_shape=[ + dd_sampling_config.continuous_tokenizer_channel, + dd_sampling_config.continuous_tokenizer_spatial_compression_ratio, + H // 8, + W // 8, + ], + apply_corruptor=False, + return_recon_x=False, + # corrupt_sigma=dd_sampling_config.sigma, + preencode_condition=True, # We are using discrete model, so the input is already pre-encoded + num_steps=dd_sampling_config.num_steps, + ) + log.debug(f"Current sample shape {samples_latent.shape} for video # {idx} ") + full_length_sample.append(samples_latent.detach()) + + # Turn off 
because we remove CP + # distributed.barrier() + del data_batch + + torch.cuda.empty_cache() + + gc.collect() + torch.cuda.empty_cache() + + # Decode full-length samples and free GPU memory + full_length_sample_pixs = [model.decode(item).clamp(-1, 1).cpu() for item in full_length_sample] + torch.cuda.empty_cache() + + # Blend pixel samples + if len(full_length_sample_pixs) > 1: + full_length_sample_pixel_blend = linear_blend_video_list( + full_length_sample_pixs, dd_sampling_config.overlap + )[:, :, :T] + else: + full_length_sample_pixel_blend = full_length_sample_pixs[0][:, :, :T] + + # Batch size of full_length_sample_pixel_blend is always 1 + out_videos_batch.append((1 + full_length_sample_pixel_blend[0].cpu()) / 2) + return out_videos_batch diff --git a/cosmos1/models/autoregressive/diffusion_decoder/model.py b/cosmos1/models/autoregressive/diffusion_decoder/model.py new file mode 100644 index 0000000000000000000000000000000000000000..4693a2868109e8a7b5e5598fbe76708fa25cd19e --- /dev/null +++ b/cosmos1/models/autoregressive/diffusion_decoder/model.py @@ -0,0 +1,231 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Callable, Dict, Optional, Tuple + +import torch +from torch import Tensor + +from conditioner import BaseVideoCondition +from batch_ops import batch_mul +from res_sampler import COMMON_SOLVER_OPTIONS +from model_t2w import DiffusionT2WModel as VideoDiffusionModel +from lazy_config_init import instantiate as lazy_instantiate + + +@dataclass +class VideoLatentDiffusionDecoderCondition(BaseVideoCondition): + # latent_condition will concat to the input of network, along channel dim; + # cfg will make latent_condition all zero padding. + latent_condition: Optional[torch.Tensor] = None + latent_condition_sigma: Optional[torch.Tensor] = None + + +class LatentDiffusionDecoderModel(VideoDiffusionModel): + def __init__(self, config): + super().__init__(config) + """ + latent_corruptor: the corruption module is used to corrupt the latents. It add gaussian noise to the latents. + pixel_corruptor: the corruption module is used to corrupt the pixels. It apply gaussian blur kernel to pixels in a temporal consistent way. + tokenizer_corruptor: the corruption module is used to simulate tokenizer reconstruction errors. + + diffusion decoder noise augmentation pipeline for continuous token condition model: + condition: GT_video [T, H, W] + -> tokenizer_corruptor~(8x8x8) encode -> latent_corruptor -> tokenizer_corruptor~(8x8x8) decode + -> pixel corruptor + -> tokenizer~(1x8x8) encode -> condition [T, H/8, W/8] + GT: GT_video [T, H, W] -> tokenizer~(1x8x8) -> x_t [T, H/8, W/8]. 
+
+        diffusion decoder noise augmentation pipeline for discrete token condition model:
+        condition: GT_video [T, H, W]
+                   -> pixel corruptor
+                   -> discrete tokenizer encode -> condition [T, T/8, H/16, W/16]
+        GT: GT_video [T, H, W] -> tokenizer~(8x8x8) -> x_t [T, T/8, H/8, W/8].
+
+        """
+        self.latent_corruptor = lazy_instantiate(config.latent_corruptor)
+        self.pixel_corruptor = lazy_instantiate(config.pixel_corruptor)
+        self.tokenizer_corruptor = lazy_instantiate(config.tokenizer_corruptor)
+
+        if self.latent_corruptor:
+            self.latent_corruptor.to(**self.tensor_kwargs)
+        if self.pixel_corruptor:
+            self.pixel_corruptor.to(**self.tensor_kwargs)
+
+        if self.tokenizer_corruptor:
+            if hasattr(self.tokenizer_corruptor, "reset_dtype"):
+                self.tokenizer_corruptor.reset_dtype()
+        else:
+            assert self.pixel_corruptor is not None
+
+        self.diffusion_decoder_cond_sigma_low = config.diffusion_decoder_cond_sigma_low
+        self.diffusion_decoder_cond_sigma_high = config.diffusion_decoder_cond_sigma_high
+        self.diffusion_decoder_corrupt_prob = config.diffusion_decoder_corrupt_prob
+        if hasattr(config, "condition_on_tokenizer_corruptor_token"):
+            self.condition_on_tokenizer_corruptor_token = config.condition_on_tokenizer_corruptor_token
+        else:
+            self.condition_on_tokenizer_corruptor_token = False
+
+    def is_image_batch(self, data_batch: dict[str, Tensor]) -> bool:
+        """We handle two types of data_batch. One comes from a joint dataloader, where "dataset_name" can be used to differentiate image batches from video batches.
+        The other comes from a dataloader that we assume by default yields video data for video model training.
+        """
+        is_image = self.input_image_key in data_batch
+        is_video = self.input_data_key in data_batch
+        assert (
+            is_image != is_video
+        ), "Only one of the input_image_key or input_data_key should be present in the data_batch."
+        return is_image
+
+    def get_x0_fn_from_batch(
+        self,
+        data_batch: Dict,
+        guidance: float = 1.5,
+        is_negative_prompt: bool = False,
+        apply_corruptor: bool = True,
+        corrupt_sigma: float = 1.5,
+        preencode_condition: bool = False,
+    ) -> Callable:
+        """
+        Generates a callable function `x0_fn` based on the provided data batch and guidance factor.
+
+        This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states.
+
+        Args:
+        - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of `self.conditioner`.
+        - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5.
+        - is_negative_prompt (bool): use the negative prompt's T5 embedding for the unconditioned branch if True
+
+        Returns:
+        - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and returns the x0 prediction.
+
+        The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence.
+ """ + input_key = self.input_data_key # by default it is video key + # Latent state + raw_state = data_batch[input_key] + + if self.condition_on_tokenizer_corruptor_token: + if preencode_condition: + latent_condition = raw_state.to(torch.int32).contiguous() + corrupted_pixel = self.tokenizer_corruptor.decode(latent_condition[:, 0]) + else: + corrupted_pixel = ( + self.pixel_corruptor(raw_state) if apply_corruptor and self.pixel_corruptor else raw_state + ) + latent_condition = self.tokenizer_corruptor.encode(corrupted_pixel) + latent_condition = latent_condition[1] if isinstance(latent_condition, tuple) else latent_condition + corrupted_pixel = self.tokenizer_corruptor.decode(latent_condition) + latent_condition = latent_condition.unsqueeze(1) + else: + if preencode_condition: + latent_condition = raw_state + corrupted_pixel = self.decode(latent_condition) + else: + corrupted_pixel = ( + self.pixel_corruptor(raw_state) if apply_corruptor and self.pixel_corruptor else raw_state + ) + latent_condition = self.encode(corrupted_pixel).contiguous() + + sigma = ( + torch.rand((latent_condition.shape[0],)).to(**self.tensor_kwargs) * corrupt_sigma + ) # small value to indicate clean video + _, _, _, c_noise_cond = self.scaling(sigma=sigma) + if corrupt_sigma != self.diffusion_decoder_cond_sigma_low and self.diffusion_decoder_corrupt_prob > 0: + noise = batch_mul(sigma, torch.randn_like(latent_condition)) + latent_condition = latent_condition + noise + data_batch["latent_condition_sigma"] = batch_mul(torch.ones_like(latent_condition[:, 0:1, ::]), c_noise_cond) + data_batch["latent_condition"] = latent_condition + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise(noise_x, sigma, condition).x0 + uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return x0_fn, corrupted_pixel + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + solver_option: COMMON_SOLVER_OPTIONS = "2ab", + sigma_min: float = 0.02, + apply_corruptor: bool = False, + return_recon_x: bool = False, + corrupt_sigma: float = 0.01, + preencode_condition: bool = False, + ) -> Tensor: + """ + Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. + Args: + data_batch (dict): raw data batch draw from the training data loader. + iteration (int): Current iteration number. 
+            guidance (float): guidance weight
+            seed (int): random seed
+            state_shape (tuple): shape of the state, defaults to self.state_shape if not provided
+            n_sample (int): number of samples to generate
+            is_negative_prompt (bool): use the negative prompt's T5 embedding for the unconditioned branch if True
+            num_steps (int): number of steps for the diffusion process
+            solver_option (str): differential equation solver option, defaults to "2ab" (multistep solver)
+            preencode_condition (bool): use a pre-computed condition if True, saving the tokenizer's inference time and memory
+        """
+        if not preencode_condition:
+            self._normalize_video_databatch_inplace(data_batch)
+            self._augment_image_dim_inplace(data_batch)
+        is_image_batch = False
+        if n_sample is None:
+            input_key = self.input_image_key if is_image_batch else self.input_data_key
+            n_sample = data_batch[input_key].shape[0]
+        if state_shape is None:
+            if is_image_batch:
+                state_shape = (self.state_shape[0], 1, *self.state_shape[2:])  # C,T,H,W
+
+        x0_fn, recon_x = self.get_x0_fn_from_batch(
+            data_batch,
+            guidance,
+            is_negative_prompt=is_negative_prompt,
+            apply_corruptor=apply_corruptor,
+            corrupt_sigma=corrupt_sigma,
+            preencode_condition=preencode_condition,
+        )
+        generator = torch.Generator(device=self.tensor_kwargs["device"])
+        generator.manual_seed(seed)
+        x_sigma_max = (
+            torch.randn(n_sample, *state_shape, **self.tensor_kwargs, generator=generator) * self.sde.sigma_max
+        )
+
+        samples = self.sampler(
+            x0_fn,
+            x_sigma_max,
+            num_steps=num_steps,
+            sigma_min=sigma_min,
+            sigma_max=self.sde.sigma_max,
+            solver_option=solver_option,
+        )
+
+        if return_recon_x:
+            return samples, recon_x
+        else:
+            return samples
diff --git a/cosmos1/models/autoregressive/diffusion_decoder/network.py b/cosmos1/models/autoregressive/diffusion_decoder/network.py
new file mode 100644
index 0000000000000000000000000000000000000000..276aae166c348b727c5c395815daedf2b3c6f8a9
--- /dev/null
+++ b/cosmos1/models/autoregressive/diffusion_decoder/network.py
@@ -0,0 +1,163 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
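Two pieces of `generate_samples_from_batch` that are easy to illustrate in isolation are the seeded initial-noise draw at `sigma_max` and the classifier-free guidance rule inside `x0_fn`. A standalone toy sketch (made-up shapes and `sigma_max`, not repository code):

```python
import torch


def apply_cfg(cond_x0: torch.Tensor, uncond_x0: torch.Tensor, guidance: float) -> torch.Tensor:
    # Classifier-free guidance: push the conditioned prediction away from the unconditioned one.
    return cond_x0 + guidance * (cond_x0 - uncond_x0)


# Seeded initial state at sigma_max (toy sigma_max and state shape).
generator = torch.Generator(device="cpu").manual_seed(1)
sigma_max = 80.0
x_sigma_max = torch.randn(1, 16, 8, 8, 8, generator=generator) * sigma_max

cond = torch.tensor([1.0, 2.0])
uncond = torch.tensor([0.5, 1.0])
print(apply_cfg(cond, uncond, guidance=1.5))  # tensor([1.7500, 3.5000])
print(x_sigma_max.shape)                      # torch.Size([1, 16, 8, 8, 8])
```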
+ +from typing import Optional, Tuple + +import torch +from einops import rearrange +from torch import nn +from torchvision import transforms + +from blocks import PatchEmbed +from general_dit import GeneralDIT + + +class DiffusionDecoderGeneralDIT(GeneralDIT): + def __init__( + self, + *args, + is_diffusion_decoder: bool = True, + diffusion_decoder_condition_on_sigma: bool = False, + diffusion_decoder_condition_on_token: bool = False, + diffusion_decoder_token_condition_voc_size: int = 64000, + diffusion_decoder_token_condition_dim: int = 32, + **kwargs, + ): + # diffusion decoder setting + self.is_diffusion_decoder = is_diffusion_decoder + self.diffusion_decoder_condition_on_sigma = diffusion_decoder_condition_on_sigma + self.diffusion_decoder_condition_on_token = diffusion_decoder_condition_on_token + self.diffusion_decoder_token_condition_voc_size = diffusion_decoder_token_condition_voc_size + self.diffusion_decoder_token_condition_dim = diffusion_decoder_token_condition_dim + super().__init__(*args, **kwargs) + + def initialize_weights(self): + # Initialize transformer layers: + super().initialize_weights() + if self.diffusion_decoder_condition_on_token: + nn.init.constant_(self.token_embedder.weight, 0) + + def build_patch_embed(self): + ( + concat_padding_mask, + in_channels, + patch_spatial, + patch_temporal, + model_channels, + is_diffusion_decoder, + diffusion_decoder_token_condition_dim, + diffusion_decoder_condition_on_sigma, + ) = ( + self.concat_padding_mask, + self.in_channels, + self.patch_spatial, + self.patch_temporal, + self.model_channels, + self.is_diffusion_decoder, + self.diffusion_decoder_token_condition_dim, + self.diffusion_decoder_condition_on_sigma, + ) + in_channels = ( + in_channels + in_channels + if (is_diffusion_decoder and not self.diffusion_decoder_condition_on_token) + else in_channels + ) + in_channels = in_channels + 1 if diffusion_decoder_condition_on_sigma else in_channels + in_channels = ( + in_channels + self.diffusion_decoder_token_condition_dim + if self.diffusion_decoder_condition_on_token + else in_channels + ) + in_channels = in_channels + 1 if concat_padding_mask else in_channels + + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + bias=False, + ) + + if self.diffusion_decoder_condition_on_token: + self.token_embedder = nn.Embedding( + self.diffusion_decoder_token_condition_voc_size, self.diffusion_decoder_token_condition_dim + ) + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. + + Args: + x_B_C_T_H_W (torch.Tensor): video + fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. + If None, a default value (`self.base_fps`) will be used. + padding_mask (Optional[torch.Tensor]): current it is not used + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - A tensor of shape (B, T, H, W, D) with the embedded sequence. + - An optional positional embedding tensor, returned only if the positional embedding class + (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. 
+ + Notes: + - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. + - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. + - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using + the `self.pos_embedder` with the shape [T, H, W]. + - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the `self.pos_embedder` + with the fps tensor. + - Otherwise, the positional embeddings are generated without considering fps. + """ + if self.diffusion_decoder_condition_on_token: + latent_condition = self.token_embedder(latent_condition) + B, _, T, H, W, _ = latent_condition.shape + latent_condition = rearrange(latent_condition, "B 1 T H W D -> (B T) (1 D) H W") + + latent_condition = transforms.functional.resize( + latent_condition, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.BILINEAR + ) + latent_condition = rearrange(latent_condition, "(B T) D H W -> B D T H W ", B=B, T=T) + x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, latent_condition], dim=1) + if self.diffusion_decoder_condition_on_sigma: + x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, latent_condition_sigma], dim=1) + if self.concat_padding_mask: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb + + if "fps_aware" in self.pos_emb_cls: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] + else: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] + return x_B_T_H_W_D, None, extra_pos_emb diff --git a/cosmos1/models/autoregressive/diffusion_decoder/utils.py b/cosmos1/models/autoregressive/diffusion_decoder/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2c584c7c9a5e03bcb3b808d053f89e7c2aeaf9cf --- /dev/null +++ b/cosmos1/models/autoregressive/diffusion_decoder/utils.py @@ -0,0 +1,119 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F + + +def split_with_overlap(video_BCTHW, num_video_frames, overlap=2, tobf16=True): + """ + Splits the video tensor into chunks of num_video_frames with a specified overlap. + + Args: + - video_BCTHW (torch.Tensor): Input tensor with shape [Batch, Channels, Time, Height, Width]. + - num_video_frames (int): Number of frames per chunk. 
+ - overlap (int): Number of overlapping frames between chunks. + + Returns: + - List of torch.Tensors: List of video chunks with overlap. + """ + # Get the dimensions of the input tensor + B, C, T, H, W = video_BCTHW.shape + + # Ensure overlap is less than num_video_frames + assert overlap < num_video_frames, "Overlap should be less than num_video_frames." + + # List to store the chunks + chunks = [] + + # Step size for the sliding window + step = num_video_frames - overlap + + # Loop through the time dimension (T) with the sliding window + for start in range(0, T - overlap, step): + end = start + num_video_frames + # Handle the case when the last chunk might go out of bounds + if end > T: + # Get the last available frame + num_padding_frames = end - T + chunk = F.pad(video_BCTHW[:, :, start:T, :, :], (0, 0, 0, 0, 0, num_padding_frames), mode="reflect") + else: + # Regular case: no padding needed + chunk = video_BCTHW[:, :, start:end, :, :] + if tobf16: + chunks.append(chunk.to(torch.bfloat16)) + else: + chunks.append(chunk) + return chunks + + +def linear_blend_video_list(videos, D): + """ + Linearly blends a list of videos along the time dimension with overlap length D. + + Parameters: + - videos: list of video tensors, each of shape [b, c, t, h, w] + - D: int, overlap length + + Returns: + - output_video: blended video tensor of shape [b, c, L, h, w] + """ + assert len(videos) >= 2, "At least two videos are required." + b, c, t, h, w = videos[0].shape + N = len(videos) + + # Ensure all videos have the same shape + for video in videos: + assert video.shape == (b, c, t, h, w), "All videos must have the same shape." + + # Calculate total output length + L = N * t - D * (N - 1) + output_video = torch.zeros((b, c, L, h, w), device=videos[0].device) + + output_index = 0 # Current index in the output video + + for i in range(N): + if i == 0: + # Copy frames from the first video up to t - D + output_video[:, :, output_index : output_index + t - D, :, :] = videos[i][:, :, : t - D, :, :] + output_index += t - D + else: + # Blend overlapping frames between videos[i-1] and videos[i] + blend_weights = torch.linspace(0, 1, steps=D, device=videos[0].device) + + for j in range(D): + w1 = 1 - blend_weights[j] + w2 = blend_weights[j] + frame_from_prev = videos[i - 1][:, :, t - D + j, :, :] + frame_from_curr = videos[i][:, :, j, :, :] + output_frame = w1 * frame_from_prev + w2 * frame_from_curr + output_video[:, :, output_index, :, :] = output_frame + output_index += 1 + + if i < N - 1: + # Copy non-overlapping frames from current video up to t - D + frames_to_copy = t - 2 * D + if frames_to_copy > 0: + output_video[:, :, output_index : output_index + frames_to_copy, :, :] = videos[i][ + :, :, D : t - D, :, : + ] + output_index += frames_to_copy + else: + # For the last video, copy frames from D to t + frames_to_copy = t - D + output_video[:, :, output_index : output_index + frames_to_copy, :, :] = videos[i][:, :, D:, :, :] + output_index += frames_to_copy + + return output_video diff --git a/cosmos1/models/autoregressive/inference/__init__.py b/cosmos1/models/autoregressive/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos1/models/autoregressive/inference/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos1/models/autoregressive/inference/base.py b/cosmos1/models/autoregressive/inference/base.py new file mode 100644 index 0000000000000000000000000000000000000000..781de773c96c7556eedab0024dfce3c20a110b55 --- /dev/null +++ b/cosmos1/models/autoregressive/inference/base.py @@ -0,0 +1,116 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import imageio +import torch + +from cosmos1.models.autoregressive.inference.world_generation_pipeline import ARBaseGenerationPipeline +from cosmos1.models.autoregressive.utils.inference import add_common_arguments, load_vision_input, validate_args +from .log import log + + +def parse_args(): + parser = argparse.ArgumentParser(description="Video to world generation demo script") + # Add common arguments + add_common_arguments(parser) + parser.add_argument( + "--ar_model_dir", + type=str, + default="Cosmos-1.0-Autoregressive-4B", + ) + parser.add_argument("--input_type", type=str, default="video", help="Type of input", choices=["image", "video"]) + args = parser.parse_args() + return args + + +def main(args): + """Run video-to-world generation demo. + + This function handles the main video-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple images/videos from input + - Generating videos from images/videos + - Saving the generated videos to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (temperature, top_p) + - Input/output settings (images/videos, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. 
+ """ + inference_type = "base" # When the inference_type is "base", AR model does not take text as input, the world generation is purely based on the input video + sampling_config = validate_args(args, inference_type) + + # Initialize base generation model pipeline + pipeline = ARBaseGenerationPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name=args.ar_model_dir, + disable_diffusion_decoder=args.disable_diffusion_decoder, + offload_guardrail_models=args.offload_guardrail_models, + offload_diffusion_decoder=args.offload_diffusion_decoder, + offload_network=args.offload_ar_model, + offload_tokenizer=args.offload_tokenizer, + ) + + # Load input image(s) or video(s) + input_videos = load_vision_input( + input_type=args.input_type, + batch_input_path=args.batch_input_path, + input_image_or_video_path=args.input_image_or_video_path, + data_resolution=args.data_resolution, + num_input_frames=args.num_input_frames, + ) + + for idx, input_filename in enumerate(input_videos): + inp_vid = input_videos[input_filename] + # Generate video + log.info(f"Run with image or video path: {input_filename}") + out_vid = pipeline.generate( + inp_vid=inp_vid, + num_input_frames=args.num_input_frames, + seed=args.seed, + sampling_config=sampling_config, + ) + if out_vid is None: + log.critical("Guardrail blocked base generation.") + continue + + # Save video + if args.input_image_or_video_path: + out_vid_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4") + else: + out_vid_path = os.path.join(args.video_save_folder, f"{idx}.mp4") + + imageio.mimsave(out_vid_path, out_vid, fps=25) + + log.info(f"Saved video to {out_vid_path}") + + +if __name__ == "__main__": + torch._C._jit_set_texpr_fuser_enabled(False) + args = parse_args() + main(args) diff --git a/cosmos1/models/autoregressive/inference/video2world.py b/cosmos1/models/autoregressive/inference/video2world.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a547f8d62c3309353ebabbf9838d448a3b855c --- /dev/null +++ b/cosmos1/models/autoregressive/inference/video2world.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import os + +import imageio +import torch + +from cosmos1.models.autoregressive.inference.world_generation_pipeline import ARVideo2WorldGenerationPipeline +from cosmos1.models.autoregressive.utils.inference import add_common_arguments, load_vision_input, validate_args +from .log import log +from io import read_prompts_from_file + + +def parse_args(): + parser = argparse.ArgumentParser(description="Prompted video to world generation demo script") + add_common_arguments(parser) + parser.add_argument( + "--ar_model_dir", + type=str, + default="Cosmos-1.0-Autoregressive-5B-Video2World", + ) + parser.add_argument( + "--input_type", + type=str, + default="text_and_video", + choices=["text_and_image", "text_and_video"], + help="Input types", + ) + parser.add_argument( + "--prompt", + type=str, + help="Text prompt for generating a single video", + ) + parser.add_argument( + "--offload_text_encoder_model", + action="store_true", + help="Offload T5 model after inference", + ) + args = parser.parse_args() + return args + + +def main(args): + """Run prompted video-to-world generation demo. + + This function handles the main video-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts/images/videos from input + - Generating videos from prompts and images/videos + - Saving the generated videos and corresponding prompts to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (temperature, top_p) + - Input/output settings (images/videos, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. 
+ """ + inference_type = "video2world" # When the inference_type is "video2world", AR model takes both text and video as input, the world generation is based on the input text prompt and video + sampling_config = validate_args(args, inference_type) + + # Initialize prompted base generation model pipeline + pipeline = ARVideo2WorldGenerationPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name=args.ar_model_dir, + disable_diffusion_decoder=args.disable_diffusion_decoder, + offload_guardrail_models=args.offload_guardrail_models, + offload_diffusion_decoder=args.offload_diffusion_decoder, + offload_network=args.offload_ar_model, + offload_tokenizer=args.offload_tokenizer, + offload_text_encoder_model=args.offload_text_encoder_model, + ) + + # Load input image(s) or video(s) + input_videos = load_vision_input( + input_type=args.input_type, + batch_input_path=args.batch_input_path, + input_image_or_video_path=args.input_image_or_video_path, + data_resolution=args.data_resolution, + num_input_frames=args.num_input_frames, + ) + # Load input prompt(s) + if args.batch_input_path: + prompts_list = read_prompts_from_file(args.batch_input_path) + else: + prompts_list = [{"visual_input": args.input_image_or_video_path, "prompt": args.prompt}] + + # Iterate through prompts + for idx, prompt_entry in enumerate(prompts_list): + video_path = prompt_entry["visual_input"] + input_filename = os.path.basename(video_path) + + # Check if video exists in loaded videos + if input_filename not in input_videos: + log.critical(f"Input file {input_filename} not found, skipping prompt.") + continue + + inp_vid = input_videos[input_filename] + inp_prompt = prompt_entry["prompt"] + + # Generate video + log.info(f"Run with input: {prompt_entry}") + out_vid = pipeline.generate( + inp_prompt=inp_prompt, + inp_vid=inp_vid, + num_input_frames=args.num_input_frames, + seed=args.seed, + sampling_config=sampling_config, + ) + if out_vid is None: + log.critical("Guardrail blocked video2world generation.") + continue + + # Save video + if args.input_image_or_video_path: + out_vid_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4") + else: + out_vid_path = os.path.join(args.video_save_folder, f"{idx}.mp4") + imageio.mimsave(out_vid_path, out_vid, fps=25) + + log.info(f"Saved video to {out_vid_path}") + + +if __name__ == "__main__": + torch._C._jit_set_texpr_fuser_enabled(False) + args = parse_args() + main(args) diff --git a/cosmos1/models/autoregressive/inference/world_generation_pipeline.py b/cosmos1/models/autoregressive/inference/world_generation_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..c34a41837ef3dafff4858190895533bb4a07a06e --- /dev/null +++ b/cosmos1/models/autoregressive/inference/world_generation_pipeline.py @@ -0,0 +1,912 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc +import os +from typing import List, Optional, Tuple + +from .log import log +import numpy as np +import torch +from einops import rearrange + +from model_config import create_video2world_model_config +from ar_config_tokenizer import TokenizerConfig +from inference_config import ( + DataShapeConfig, + DiffusionDecoderSamplingConfig, + InferenceConfig, + SamplingConfig, +) +from cosmos1.models.autoregressive.diffusion_decoder.inference import diffusion_decoder_process_tokens +from cosmos1.models.autoregressive.diffusion_decoder.model import LatentDiffusionDecoderModel +from ar_model import AutoRegressiveModel +from cosmos1.models.autoregressive.utils.inference import _SUPPORTED_CONTEXT_LEN, prepare_video_batch_for_saving +from base_world_generation_pipeline import BaseWorldGenerationPipeline +from inference_utils import ( + load_model_by_config, + load_network_model, + load_tokenizer_model, +) +from .misc import misc, Color, timer + + +def detect_model_size_from_ckpt_path(ckpt_path: str) -> str: + """Detect model size from checkpoint path. + + Args: + ckpt_path: Path to model checkpoint file + + Returns: + str: Model size ('4b', '5b', '12b', or '13b') + + Examples: + >>> detect_model_size_from_ckpt_path("model_4B.pt") + '4b' + """ + model_size = "4b" + if "4B" in ckpt_path: + model_size = "4b" + elif "5B" in ckpt_path: + model_size = "5b" + elif "12B" in ckpt_path: + model_size = "12b" + elif "13B" in ckpt_path: + model_size = "13b" + else: + log.warning(f"Could not detect model size from checkpoint path: {ckpt_path}") + return model_size + + +def create_inference_config( + model_ckpt_path: str, + tokenizer_ckpt_path: str, + model_size: str = "4b", + batch_size: int = 1, + inference_type: str = "base", +) -> InferenceConfig: + """Create inference configuration for model. 
+ + Args: + model_ckpt_path: Path to model checkpoint + tokenizer_ckpt_path: Path to tokenizer checkpoint + model_size: Size of model ('4b', '5b', '12b', '13b') + batch_size: Batch size for inference + inference_type: Type of inference ('base' or 'video2world') + + Returns: + InferenceConfig: Configuration object for inference + """ + model_size = model_size.lower() + # For inference config + kwargs = {} + if inference_type == "video2world": + kwargs.update( + dict( + insert_cross_attn=True, + insert_cross_attn_every_k_layers=1, + context_dim=1024, + training_type="text_to_video", + apply_abs_pos_emb=True, + ) + ) + if model_size == "5b": + model_size = "4b" # The base model (excluding the cross attention layers) is the 4B model + elif model_size == "13b": + model_size = "12b" # The base model (excluding the cross attention layers) is the 12B model + else: + raise ValueError(f"Unsupported model size for video2world inference_type: {model_size}") + else: + assert inference_type == "base", f"Unsupported inference_type: {inference_type}" + + model_config, tokenizer_config = create_video2world_model_config( + model_ckpt_path=model_ckpt_path, + tokenizer_ckpt_path=tokenizer_ckpt_path, + model_size=model_size, + rope_dim="3D", + add_special_tokens=False, + pixel_chunk_duration=33, + num_video_frames=33, + num_condition_latents_t=1, + batch_size=batch_size, + video_height=640, + video_width=1024, + **kwargs, + ) + + inference_config = InferenceConfig() + + inference_config.model_config = model_config + inference_config.tokenizer_config = tokenizer_config + + inference_config.data_shape_config = DataShapeConfig( + num_video_frames=model_config.num_video_frames, + height=model_config.video_height, + width=model_config.video_width, + latent_shape=model_config.video_latent_shape, + ) + inference_config.model_config.fuse_qkv = False + return inference_config + + +class ARBaseGenerationPipeline(BaseWorldGenerationPipeline): + """Base class for autoregressive world generation models. + + Handles the core functionality for generating videos using autoregressive models. + Provides configurable GPU memory management through model offloading and supports + different inference types for video generation. + + Attributes: + inference_config (InferenceConfig): Configuration for model inference + tokenizer_config (TokenizerConfig): Configuration for tokenizer + disable_diffusion_decoder (bool): Whether diffusion decoder is disabled + latent_shape (List[int]): Shape of video latents [T, H, W] + _supported_context_len (int): Supported context window length + latent_chunk_duration (int): Duration of latent chunks + pixel_chunk_duration (int): Duration of pixel chunks + diffusion_decoder_model (Optional[nn.Module]): The diffusion decoder model + """ + + def __init__( + self, + inference_type: str, + checkpoint_dir: str, + checkpoint_name: str, + enable_text_guardrail: bool = False, + enable_video_guardrail: bool = True, + offload_network: bool = False, + offload_tokenizer: bool = False, + disable_diffusion_decoder: bool = False, + offload_guardrail_models: bool = False, + offload_diffusion_decoder: bool = False, + ): + """Initialize the autoregressive world generation pipeline. 
+ + Args: + inference_type: Type of world generation ('base' or 'video2world') + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the AR checkpoint to load + enable_text_guardrail: Whether to enable text content filtering + enable_video_guardrail: Whether to enable video content filtering + disable_diffusion_decoder: Whether to disable the diffusion decoder stage + offload_network: Whether to offload AR model from GPU after use + offload_guardrail_models: Whether to offload content filtering models + offload_diffusion_decoder: Whether to offload diffusion decoder + + Raises: + AssertionError: If inference_type is not 'base' or 'video2world' + """ + assert inference_type in [ + "base", + "video2world", + ], "Invalid inference_type, must be 'base' or 'video2world'" + + # Create inference config + model_size = detect_model_size_from_ckpt_path(checkpoint_name) + model_ckpt_path = os.path.join(checkpoint_dir, checkpoint_name, "model.pt") + tokenizer_ckpt_path = os.path.join(checkpoint_dir, "Cosmos-1.0-Tokenizer-DV8x16x16/ema.jit") + + inference_config: InferenceConfig = create_inference_config( + model_ckpt_path=model_ckpt_path, + tokenizer_ckpt_path=tokenizer_ckpt_path, + model_size=model_size, + inference_type=inference_type, + ) + + self.inference_config = inference_config + self.disable_diffusion_decoder = disable_diffusion_decoder + + if not disable_diffusion_decoder: + self.diffusion_decoder_ckpt_path = os.path.join( + checkpoint_dir, "Cosmos-1.0-Diffusion-7B-Decoder-DV8x16x16ToCV8x8x8/model.pt" + ) + self.diffusion_decoder_config = "DD_FT_7Bv1_003_002_tokenizer888_spatch2_discrete_cond_on_token" + self.diffusion_decoder_tokenizer_path = os.path.join(checkpoint_dir, "Cosmos-1.0-Tokenizer-CV8x8x8") + self.dd_sampling_config = DiffusionDecoderSamplingConfig() + aux_vars_path = os.path.join(os.path.dirname(self.diffusion_decoder_ckpt_path), "aux_vars.pt") + # We use a generic prompt when no text prompts are available for diffusion decoder. + # Generic prompt used - "high quality, 4k, high definition, smooth video" + aux_vars = torch.load(aux_vars_path, weights_only=True) + self.generic_prompt = dict() + self.generic_prompt["context"] = aux_vars["context"].cuda() + self.generic_prompt["context_mask"] = aux_vars["context_mask"].cuda() + + self.latent_shape = inference_config.data_shape_config.latent_shape # [L, 40, 64] + self._supported_context_len = _SUPPORTED_CONTEXT_LEN + self.tokenizer_config = inference_config.tokenizer_config + + self.offload_diffusion_decoder = offload_diffusion_decoder + self.diffusion_decoder_model = None + if not self.offload_diffusion_decoder and not disable_diffusion_decoder: + self._load_diffusion_decoder() + + super().__init__( + inference_type=inference_type, + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + enable_text_guardrail=enable_text_guardrail, + enable_video_guardrail=enable_video_guardrail, + offload_guardrail_models=offload_guardrail_models, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + offload_text_encoder_model=True, + ) + + def _load_model(self): + """Load and initialize the autoregressive model. + + Creates and configures the autoregressive model with appropriate settings. 
+ """ + self.model = AutoRegressiveModel( + config=self.inference_config.model_config, + ) + + def _load_network(self): + """Load network weights for the autoregressive model.""" + self.model.load_ar_model(tokenizer_config=self.inference_config.tokenizer_config) + + def _load_tokenizer(self): + """Load and initialize the tokenizer model. + + Configures the tokenizer using settings from inference_config and + attaches it to the autoregressive model. + """ + self.model.load_tokenizer(tokenizer_config=self.inference_config.tokenizer_config) + + def _load_diffusion_decoder(self): + """Load and initialize the diffusion decoder model.""" + self.diffusion_decoder_model = load_model_by_config( + config_job_name=self.diffusion_decoder_config, + config_file="cosmos1/models/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py", + model_class=LatentDiffusionDecoderModel, + ) + load_network_model(self.diffusion_decoder_model, self.diffusion_decoder_ckpt_path) + load_tokenizer_model(self.diffusion_decoder_model, self.diffusion_decoder_tokenizer_path) + + def _offload_diffusion_decoder(self): + """Offload diffusion decoder model from GPU memory.""" + if self.diffusion_decoder_model is not None: + del self.diffusion_decoder_model + self.diffusion_decoder_model = None + gc.collect() + torch.cuda.empty_cache() + + def _run_model_with_offload( + self, inp_vid: torch.Tensor, num_input_frames: int, seed: int, sampling_config: SamplingConfig + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """Run the autoregressive model to generate video tokens. + + Takes input video frames and generates new video tokens using the autoregressive model. + Handles context frame selection and token generation. + + Args: + inp_vid (torch.Tensor): Input video tensor of shape + num_input_frames (int): Number of context frames to use from input. The tensor shape should be (B x T x 3 x H x W). + seed (int): Random seed for generation + sampling_config (SamplingConfig): Configuration for sampling parameters + + Returns: + tuple: ( + List of generated video tensors, + List of token index tensors, + List of prompt embedding tensors + ) + """ + # Choosing the context length from list of available contexts + latent_context_t_size = 0 + context_used = 0 + for _clen in self._supported_context_len: + if num_input_frames >= _clen: + context_used = _clen + latent_context_t_size += 1 + log.info(f"Using input size of {context_used} frames") + + data_batch = {"video": inp_vid} + data_batch = misc.to(data_batch, "cuda") + + T, H, W = self.latent_shape + num_gen_tokens = int(np.prod([T - latent_context_t_size, H, W])) + + out_videos_cur_batch, indices_tensor_cur_batch = self.generate_partial_tokens_from_data_batch( + data_batch=data_batch, + num_tokens_to_generate=num_gen_tokens, + sampling_config=sampling_config, + tokenizer_config=self.tokenizer_config, + latent_shape=self.latent_shape, + task_condition="video", + num_chunks_to_generate=1, + seed=seed, + ) + if self.offload_network: + self._offload_network() + if self.offload_tokenizer: + self._offload_tokenizer() + return out_videos_cur_batch, indices_tensor_cur_batch + + def _run_diffusion_decoder( + self, + out_videos_cur_batch: List[torch.Tensor], + indices_tensor_cur_batch: List[torch.Tensor], + t5_emb_batch: List[torch.Tensor], + ) -> List[torch.Tensor]: + """Process generated tokens through the diffusion decoder. + + Enhances video quality through diffusion-based decoding. 
+ + Args: + out_videos_cur_batch: List of generated video tensors + indices_tensor_cur_batch: List of token indices tensors + t5_emb_batch: List of text embeddings for conditioning + + Returns: + list: Enhanced video tensors after diffusion processing + """ + out_videos_cur_batch_dd = diffusion_decoder_process_tokens( + model=self.diffusion_decoder_model, + indices_tensor=indices_tensor_cur_batch, + dd_sampling_config=self.dd_sampling_config, + original_video_example=out_videos_cur_batch[0], + t5_emb_batch=t5_emb_batch, + ) + return out_videos_cur_batch_dd + + def _run_diffusion_decoder_with_offload( + self, + out_videos_cur_batch: List[torch.Tensor], + indices_tensor_cur_batch: List[torch.Tensor], + t5_emb_batch: List[torch.Tensor], + ) -> List[torch.Tensor]: + """Run diffusion decoder with memory management. + + Loads decoder if needed, processes videos, and offloads decoder afterward + if configured in offload_diffusion_decoder. + + Args: + out_videos_cur_batch: List of generated video tensors + indices_tensor_cur_batch: List of token indices tensors + t5_emb_batch: List of text embeddings for conditioning + + Returns: + list: Enhanced video tensors after diffusion processing + """ + if self.offload_diffusion_decoder: + self._load_diffusion_decoder() + out_videos_cur_batch = self._run_diffusion_decoder(out_videos_cur_batch, indices_tensor_cur_batch, t5_emb_batch) + if self.offload_diffusion_decoder: + self._offload_diffusion_decoder() + return out_videos_cur_batch + + def generate( + self, + inp_vid: torch.Tensor, + sampling_config: SamplingConfig, + num_input_frames: int = 9, + seed: int = 0, + ) -> np.ndarray | None: + """Generate a video continuation from input frames. + + Pipeline steps: + 1. Generates video tokens using autoregressive model + 2. Optionally enhances quality via diffusion decoder + 3. 
Applies safety checks if enabled + + Args: + inp_vid: Input video tensor of shape (batch_size, time, channels=3, height, width) + sampling_config: Parameters controlling the generation process + num_input_frames: Number of input frames to use as context (default: 9) + seed: Random seed for reproducibility (default: 0) + + Returns: + np.ndarray | None: Generated video as numpy array (time, height, width, channels) + if generation successful, None if safety checks fail + """ + log.info("Run generation") + out_videos_cur_batch, indices_tensor_cur_batch = self._run_model_with_offload( + inp_vid, num_input_frames, seed, sampling_config + ) + log.info("Finish AR model generation") + + if not self.disable_diffusion_decoder: + log.info("Run diffusion decoder on generated tokens") + out_videos_cur_batch = self._run_diffusion_decoder_with_offload( + out_videos_cur_batch, indices_tensor_cur_batch, t5_emb_batch=[self.generic_prompt["context"]] + ) + log.info("Finish diffusion decoder on generated tokens") + out_videos_cur_batch = prepare_video_batch_for_saving(out_videos_cur_batch) + output_video = out_videos_cur_batch[0] + + if self.enable_video_guardrail: + log.info("Run guardrail on generated video") + output_video = self._run_guardrail_on_video_with_offload(output_video) + if output_video is None: + log.critical("Generated video is not safe") + return None + log.info("Finish guardrail on generated video") + + return output_video + + @torch.inference_mode() + def generate_partial_tokens_from_data_batch( + self, + data_batch: dict, + num_tokens_to_generate: int, + sampling_config: SamplingConfig, + tokenizer_config: TokenizerConfig, + latent_shape: list[int], + task_condition: str, + num_chunks_to_generate: int = 1, + seed: int = 0, + ) -> tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + """Generate video tokens from partial input tokens with conditioning. + + Handles token generation and decoding process: + 1. Processes input batch and applies conditioning + 2. Generates specified number of new tokens + 3. 
Decodes tokens to video frames + + Args: + data_batch: Dictionary containing input data including video and optional context + num_tokens_to_generate: Number of tokens to generate + sampling_config: Configuration for sampling parameters + tokenizer_config: Configuration for tokenizer, including video tokenizer settings + latent_shape: Shape of video latents [T, H, W] + task_condition: Type of generation task ('video' or 'text_and_video') + num_chunks_to_generate: Number of chunks to generate (default: 1) + seed: Random seed for generation (default: 0) + + Returns: + tuple containing: + - List[torch.Tensor]: Generated videos + - List[torch.Tensor]: Input videos + - List[torch.Tensor]: Generated tokens + - List[torch.Tensor]: Token index tensors + """ + log.debug(f"Starting generate_partial_tokens_from_data_batch with seed {seed}") + log.debug(f"Number of tokens to generate: {num_tokens_to_generate}") + log.debug(f"Latent shape: {latent_shape}") + + video_token_start = tokenizer_config.video_tokenizer.tokenizer_offset + video_vocab_size = tokenizer_config.video_tokenizer.vocab_size + video_token_end = video_token_start + video_vocab_size + + logit_clipping_range = [video_token_start, video_token_end] + + if self.offload_network: + self._offload_network() + if self.offload_tokenizer: + self._load_tokenizer() + + assert logit_clipping_range == [ + 0, + self.model.tokenizer.video_vocab_size, + ], f"logit_clipping_range {logit_clipping_range} is not supported for fast generate. Expected [0, {self.model.tokenizer.video_vocab_size}]" + + out_videos = {} + out_indices_tensors = {} + + # for text2world, we only add a token at the beginning of the video tokens, this applies to 5B and 13B models + if self.model.tokenizer.tokenizer_config.training_type == "text_to_video": + num_bov_tokens = 1 + num_eov_tokens = 0 + else: + num_eov_tokens = 1 if self.model.tokenizer.tokenizer_config.add_special_tokens else 0 + num_bov_tokens = 1 if self.model.tokenizer.tokenizer_config.add_special_tokens else 0 + + chunk_idx = 0 + out_videos[chunk_idx] = [] + out_indices_tensors[chunk_idx] = [] + + # get the context embedding and mask + context = data_batch.get("context", None) if task_condition != "video" else None + context_mask = data_batch.get("context_mask", None) if task_condition != "video" else None + if context is not None: + context = misc.to(context, "cuda").detach().clone() + if context_mask is not None: + context_mask = misc.to(context_mask, "cuda").detach().clone() + + # get the video tokens + data_tokens, token_boundaries = self.model.tokenizer.tokenize(data_batch=data_batch) + data_tokens = misc.to(data_tokens, "cuda").detach().clone() + batch_size = data_tokens.shape[0] + + for sample_num in range(batch_size): + input_tokens = data_tokens[sample_num][0 : token_boundaries["video"][sample_num][1]] # [B, L] + input_tokens = [ + input_tokens[0 : -num_tokens_to_generate - num_eov_tokens].tolist() + ] # -1 is to exclude eov token + log.debug( + f"Run sampling. 
# input condition tokens: {len(input_tokens[0])}; # generate tokens: {num_tokens_to_generate + num_eov_tokens}; " + f"full length of the data tokens: {len(data_tokens[sample_num])}: {data_tokens[sample_num]}" + ) + video_start_boundary = token_boundaries["video"][sample_num][0] + num_bov_tokens + + video_decoded, indices_tensor = self.generate_video_from_tokens( + prompt_tokens=input_tokens, + latent_shape=latent_shape, + video_start_boundary=video_start_boundary, + max_gen_len=num_tokens_to_generate, + sampling_config=sampling_config, + logit_clipping_range=logit_clipping_range, + seed=seed, + context=context, + context_mask=context_mask, + ) # BCLHW, range [0, 1] + + # For the first chunk, we store the entire generated video + out_videos[chunk_idx].append(video_decoded[sample_num].detach().clone()) + out_indices_tensors[chunk_idx].append(indices_tensor[sample_num].detach().clone()) + + output_videos = [] + output_indice_tensors = [] + for sample_num in range(len(out_videos[0])): + tensors_to_concat = [out_videos[chunk_idx][sample_num] for chunk_idx in range(num_chunks_to_generate)] + concatenated = torch.cat(tensors_to_concat, dim=1) + output_videos.append(concatenated) + + indices_tensor_to_concat = [ + out_indices_tensors[chunk_idx][sample_num] for chunk_idx in range(num_chunks_to_generate) + ] + concatenated_indices_tensor = torch.cat(indices_tensor_to_concat, dim=1) # BLHW + output_indice_tensors.append(concatenated_indices_tensor) + + return output_videos, output_indice_tensors + + def generate_video_from_tokens( + self, + prompt_tokens: list[torch.Tensor], + latent_shape: list[int], + video_start_boundary: int, + max_gen_len: int, + sampling_config: SamplingConfig, + logit_clipping_range: list[int], + seed: int = 0, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Function to generate video from input tokens. These input tokens can be initial text tokens (in the case of text to video), + or partial ground truth tokens. + + Handles the core token-to-video generation process: + 1. Generates new tokens using the autoregressive model + 2. Handles padding and token sequence completion + 3. Reshapes and processes generated tokens + 4. Decodes final tokens into video frames + + Args: + prompt_tokens (list): Prompt tokens used by the model + latent_shape (list): Shape of the video latents + video_start_boundary (int): Index where the video tokens start + max_gen_len (int): Maximum number of tokens to be generated + sampling_config (SamplingConfig): Config used by sampler during inference + logit_clipping_range (list): Range of indices in the logits to be clipped, e.g. [video_token_start, video_token_end] + seed (int): Random seed for reproducibility + context (Optional[torch.Tensor]): The context tensor added via cross-attn. + context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor.
+ Returns: + tuple containing: + - torch.Tensor: Decoded video frames in the range [0, 1] + - torch.Tensor: Generated token indices reshaped to (B, T, H, W) + """ + # Total number of tokens expected for the full latent video + total_seq_len = np.prod(latent_shape) + + assert not sampling_config.logprobs + + stop_tokens = self.model.tokenizer.stop_tokens + if self.offload_tokenizer: + self._offload_tokenizer() + if self.offload_network: + self._load_network() + + generation_tokens, _ = self.model.generate( + prompt_tokens=prompt_tokens, + temperature=sampling_config.temperature, + top_p=sampling_config.top_p, + echo=sampling_config.echo, + seed=seed, + context=context, + context_mask=context_mask, + max_gen_len=max_gen_len, + compile_sampling=sampling_config.compile_sampling, + compile_prefill=sampling_config.compile_prefill, + stop_tokens=stop_tokens, + verbose=True, + ) + generation_tokens = generation_tokens[:, video_start_boundary:] + # Combine the tokens and do padding, sometimes the generated tokens end before the max_gen_len + if generation_tokens.shape[1] < total_seq_len: + log.warning( + f"Generated video tokens (shape: {generation_tokens.shape}) are shorter than the expected length {total_seq_len}. The model may have produced an end token early. Repeating the last token to fill the sequence for decoding." + ) + padding_len = total_seq_len - generation_tokens.shape[1] + padding_tokens = generation_tokens[:, [-1]].repeat(1, padding_len) + generation_tokens = torch.cat([generation_tokens, padding_tokens], dim=1) + # Cast to LongTensor + indices_tensor = generation_tokens.long() + # First, we reshape the generated tokens into batch x time x height x width + indices_tensor = rearrange( + indices_tensor, + "B (T H W) -> B T H W", + T=latent_shape[0], + H=latent_shape[1], + W=latent_shape[2], + ) + log.debug(f"generated video tokens {len(generation_tokens[0])} -> reshape: {indices_tensor.shape}") + # If logit clipping range is specified, offset the generated indices by the logit_clipping_range[0] + # Video decoder always takes tokens in the range (0, N-1). So, this offset is needed. + if len(logit_clipping_range) > 0: + indices_tensor = indices_tensor - logit_clipping_range[0] + + if self.offload_network: + self._offload_network() + if self.offload_tokenizer: + self._load_tokenizer() + + # Now decode the video using tokenizer. + video_decoded = self.model.tokenizer.video_tokenizer.decode(indices_tensor.cuda()) + # Normalize decoded video from [-1, 1] to [0, 1], and clip value + video_decoded = (video_decoded * 0.5 + 0.5).clamp_(0, 1) + return video_decoded, indices_tensor + + +class ARVideo2WorldGenerationPipeline(ARBaseGenerationPipeline): + """Video-to-world generation pipeline with text conditioning capabilities. + + Extends the base autoregressive generation pipeline by adding: + - Text prompt processing and embedding + - Text-conditioned video generation + - Additional safety checks for text input + - Memory management for text encoder model + + Enables generating video continuations that are guided by both + input video frames and text descriptions.
+ + Additional attributes compared to ARBaseGenerationPipeline: + offload_text_encoder_model (bool): Whether to offload text encoder from GPU after use + """ + + def __init__( + self, + checkpoint_dir: str, + checkpoint_name: str, + inference_type: str = None, + enable_text_guardrail: bool = True, + enable_video_guardrail: bool = True, + disable_diffusion_decoder: bool = False, + offload_guardrail_models: bool = False, + offload_diffusion_decoder: bool = False, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + ): + """Initialize text-conditioned video generation pipeline. + + Args: + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the checkpoint to load + inference_type: Type of world generation workflow + enable_text_guardrail: Whether to enable content filtering for text (default: True) + enable_video_guardrail: Whether to enable content filtering for video (default: True) + disable_diffusion_decoder: Whether to disable diffusion decoder stage + offload_guardrail_models: Whether to offload content filtering models + offload_diffusion_decoder: Whether to offload diffusion decoder + offload_network: Whether to offload AR model from GPU + offload_tokenizer: Whether to offload tokenizer from GPU + offload_text_encoder_model: Whether to offload text encoder + """ + super().__init__( + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + inference_type=inference_type, + enable_text_guardrail=enable_text_guardrail, + enable_video_guardrail=enable_video_guardrail, + disable_diffusion_decoder=disable_diffusion_decoder, + offload_guardrail_models=offload_guardrail_models, + offload_diffusion_decoder=offload_diffusion_decoder, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + ) + self.offload_text_encoder_model = offload_text_encoder_model + if not self.offload_text_encoder_model: + self._load_text_encoder_model() + + def _run_model_with_offload( + self, + prompt_embedding: torch.Tensor, + prompt_mask: torch.Tensor, + inp_vid: torch.Tensor, + num_input_frames: int, + seed: int, + sampling_config: SamplingConfig, + ) -> tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + """Run model generation with memory management. + + Executes generation process and handles model offloading to manage GPU memory. + + Args: + prompt_embedding: Text prompt embeddings tensor + prompt_mask: Attention mask for prompt embeddings + inp_vid: Input video tensor + num_input_frames: Number of input frames to use + seed: Random seed for reproducibility + sampling_config: Configuration for sampling parameters + + Returns: + tuple: ( + List of generated video tensors + List of token index tensors + List of prompt embedding tensors + ) + """ + out_videos, indices_tensor, prompt_embedding = self._run_model( + prompt_embedding, prompt_mask, inp_vid, num_input_frames, seed, sampling_config + ) + if self.offload_network: + self._offload_network() + if self.offload_tokenizer: + self._offload_tokenizer() + return out_videos, indices_tensor, prompt_embedding + + def _run_model( + self, + prompt_embedding: torch.Tensor, + prompt_mask: torch.Tensor, + inp_vid: torch.Tensor, + num_input_frames: int, + seed: int, + sampling_config: SamplingConfig, + ) -> tuple[List[torch.Tensor], List[torch.Tensor], torch.Tensor]: + """Run core model generation process. + + Handles text-conditioned video generation: + 1. Prepares data batch with text embeddings and video + 2. 
Determines appropriate context length + 3. Generates video tokens with text conditioning + 4. Processes output tensors + + Args: + prompt_embedding: Text prompt embeddings tensor + prompt_mask: Attention mask for prompt embeddings + inp_vid: Input video tensor + num_input_frames: Number of input frames to use + seed: Random seed for reproducibility + sampling_config: Configuration for sampling parameters, + uses default config if None + + Returns: + tuple: ( + List of generated video tensors + List of token index tensors + Text context tensor + ) + """ + data_batch = {} + data_batch["context"], data_batch["context_mask"] = prompt_embedding, prompt_mask + T, H, W = self.latent_shape + + if sampling_config is None: + sampling_config = self.sampling_config + if type(inp_vid) is list: + batch_size = len(inp_vid) + elif type(inp_vid) is torch.Tensor: + batch_size = 1 + data_batch["context"] = data_batch["context"].repeat(batch_size, 1, 1) + data_batch["context_mask"] = data_batch["context_mask"].repeat(batch_size, 1) + data_batch["context_mask"] = torch.ones_like(data_batch["context_mask"]).bool() + + latent_context_t_size = 0 + + # Choosing the context length from list of available contexts + context_used = 0 + for _clen in self._supported_context_len: + if num_input_frames >= _clen: + context_used = _clen + latent_context_t_size += 1 + log.info(f"Using context of {context_used} frames") + + num_gen_tokens = int(np.prod([T - latent_context_t_size, H, W])) + + data_batch["video"] = inp_vid + data_batch["video"] = data_batch["video"].repeat(batch_size, 1, 1, 1, 1) + + data_batch = misc.to(data_batch, "cuda") + + log.debug(f" num_tokens_to_generate: {num_gen_tokens}") + log.debug(f" sampling_config: {sampling_config}") + log.debug(f" tokenizer_config: {self.tokenizer_config}") + log.debug(f" latent_shape: {self.latent_shape}") + log.debug(f" latent_context_t_size: {latent_context_t_size}") + log.debug(f" seed: {seed}") + + out_videos_cur_batch, indices_tensor_cur_batch = self.generate_partial_tokens_from_data_batch( + data_batch=data_batch, + num_tokens_to_generate=num_gen_tokens, + sampling_config=sampling_config, + tokenizer_config=self.tokenizer_config, + latent_shape=self.latent_shape, + task_condition="text_and_video", + seed=seed, + ) + return out_videos_cur_batch, indices_tensor_cur_batch, data_batch["context"] + + def generate( + self, + inp_prompt: str, + inp_vid: torch.Tensor, + num_input_frames: int = 9, + seed: int = 0, + sampling_config: SamplingConfig = None, + ) -> np.ndarray | None: + """Generate a video guided by text prompt and input frames. + + Pipeline steps: + 1. Validates text prompt safety if enabled + 2. Converts text to embeddings + 3. Generates video with text conditioning + 4. Enhances quality via diffusion decoder + 5. 
Applies video safety checks if enabled + + Args: + inp_prompt: Text prompt to guide the generation + inp_vid: Input video tensor with shape (batch_size, time, channels=3, height, width) + num_input_frames: Number of frames to use as context (default: 9) + seed: Random seed for reproducibility (default: 0) + sampling_config: Configuration for sampling parameters, + uses default config if None + + Returns: + np.ndarray | None: Generated video as numpy array (time, height, width, channels) + if generation successful, None if safety checks fail + """ + if self.enable_text_guardrail: + log.info("Run guardrail on prompt") + is_safe = self._run_guardrail_on_prompt_with_offload(inp_prompt) + if not is_safe: + log.critical("Input text prompt is not safe") + return None + log.info("Pass guardrail on prompt") + + log.info("Run text embedding on prompt") + prompt_embeddings, prompt_masks = self._run_text_embedding_on_prompt_with_offload([inp_prompt]) + prompt_embedding = prompt_embeddings[0] + prompt_mask = prompt_masks[0] + log.info("Finish text embedding on prompt") + + log.info("Run generation") + out_videos_cur_batch, indices_tensor_cur_batch, prompt_embedding = self._run_model_with_offload( + prompt_embedding, prompt_mask, inp_vid, num_input_frames, seed, sampling_config + ) + log.info("Finish AR model generation") + + if not self.disable_diffusion_decoder: + log.info("Run diffusion decoder on generated tokens") + out_videos_cur_batch = self._run_diffusion_decoder_with_offload( + out_videos_cur_batch, indices_tensor_cur_batch, [prompt_embedding] + ) + log.info("Finish diffusion decoder on generated tokens") + out_videos_cur_batch = prepare_video_batch_for_saving(out_videos_cur_batch) + output_video = out_videos_cur_batch[0] + + if self.enable_video_guardrail: + log.info("Run guardrail on generated video") + output_video = self._run_guardrail_on_video_with_offload(output_video) + if output_video is None: + log.critical("Generated video is not safe") + return None + log.info("Finish guardrail on generated video") + + return output_video diff --git a/cosmos1/models/autoregressive/modules/__init__.py b/cosmos1/models/autoregressive/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos1/models/autoregressive/modules/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
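For orientation, here is a minimal usage sketch of the text-conditioned pipeline defined above (editorial example, not part of the diff). It assumes the class is importable from the inference package of this repo, that any `inference_type` value required by the base class is supplied, and that the input clip has already been preprocessed to the `(batch, time, 3, height, width)` layout described in the docstring:

```python
import imageio
import torch

# Assumed import path; point this at the module where the pipeline classes above are defined.
from cosmos1.models.autoregressive.inference.world_generation_pipeline import ARVideo2WorldGenerationPipeline

pipeline = ARVideo2WorldGenerationPipeline(
    checkpoint_dir="checkpoints",
    checkpoint_name="Cosmos-1.0-Autoregressive-5B-Video2World",  # example checkpoint name
    offload_diffusion_decoder=True,  # the offload flags trade throughput for lower peak GPU memory
)

input_clip = torch.rand(1, 9, 3, 640, 1024)  # placeholder; use a real preprocessed clip here
video = pipeline.generate(
    inp_prompt="A robot arm picks up a red cube from a cluttered table.",
    inp_vid=input_clip,
    num_input_frames=9,
    seed=0,
)
if video is None:
    print("Generation was blocked by a guardrail check.")
else:
    # `video` is a (time, height, width, channels) numpy array; convert to uint8 if needed before saving.
    imageio.mimsave("video2world_sample.mp4", video, fps=25)
```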
diff --git a/cosmos1/models/autoregressive/nemo/__init__.py b/cosmos1/models/autoregressive/nemo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos1/models/autoregressive/nemo/cosmos.py b/cosmos1/models/autoregressive/nemo/cosmos.py new file mode 100644 index 0000000000000000000000000000000000000000..7b34fddaa3e4fce0a77f637b4c090d173aad303a --- /dev/null +++ b/cosmos1/models/autoregressive/nemo/cosmos.py @@ -0,0 +1,356 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Callable, Optional + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from nemo.collections.llm.gpt.model.llama import Llama3Config, LlamaModel +from nemo.collections.llm.utils import Config +from nemo.lightning import OptimizerModule, io +from nemo.lightning.base import teardown +from torch import Tensor, nn + +from .log import log + + +class RotaryEmbedding3D(RotaryEmbedding): + """Rotary Embedding3D for Cosmos Language model. + Args: + kv_channels (int): Projection weights dimension in multi-head attention. Obtained + from transformer config + rotary_base (int, optional): Base period for rotary position embeddings. Defaults to + 10000. + use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly + on the GPU. 
Defaults to False + latent_shape: The shape of the latents produced by the video after being tokenized + """ + + def __init__( + self, + seq_len: int, + kv_channels: int, + training_type: str = None, + rotary_base: int = 10000, + use_cpu_initialization: bool = False, + latent_shape=[5, 40, 64], + apply_yarn=False, + original_latent_shape=None, + beta_fast=32, + beta_slow=1, + scale=None, + max_position_embeddings=None, + original_max_position_embeddings=None, + extrapolation_factor=1, + attn_factor=1, + ) -> None: + super().__init__( + kv_channels=kv_channels, + rotary_base=rotary_base, + rotary_percent=1.0, + use_cpu_initialization=use_cpu_initialization, + ) + self.latent_shape = latent_shape + self.device = "cpu" if use_cpu_initialization else torch.cuda.current_device() + self.dim = kv_channels + self.rope_theta = rotary_base + self.apply_yarn = apply_yarn + self.original_latent_shape = original_latent_shape + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.scale = scale + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.attn_factor = attn_factor + dim_h = self.dim // 6 * 2 + dim_t = self.dim - 2 * dim_h + self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.device) / dim_h + spatial_inv_freq = 1.0 / (self.rope_theta**self.dim_spatial_range) + self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.device) / dim_t + temporal_inv_freq = 1.0 / (self.rope_theta**self.dim_temporal_range) + if self.apply_yarn: + assert self.original_latent_shape is not None, "Original latent shape required." + assert self.beta_slow is not None, "Beta slow value required." + assert self.beta_fast is not None, "Beta fast value required." + scale_factors_spatial = self.get_scale_factors(spatial_inv_freq, self.original_latent_shape[1]) + spatial_inv_freq = spatial_inv_freq * scale_factors_spatial + scale_factors_temporal = self.get_scale_factors(temporal_inv_freq, self.original_latent_shape[0]) + temporal_inv_freq = temporal_inv_freq * scale_factors_temporal + self.mscale = float(self.get_mscale(self.scale) * self.attn_factor) + self.spatial_inv_freq = spatial_inv_freq + self.temporal_inv_freq = temporal_inv_freq + max_seq_len_cached = max(self.latent_shape) + if self.apply_yarn and seq_len > max_seq_len_cached: + max_seq_len_cached = seq_len + self.max_seq_len_cached = max_seq_len_cached + self.freqs = self.get_freqs_non_repeated(self.max_seq_len_cached) + + def get_mscale(self, scale: float = 1.0) -> float: + """Get the magnitude scaling factor for YaRN.""" + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + def get_scale_factors(self, inv_freq: torch.Tensor, original_seq_len: int) -> torch.Tensor: + """Get the scale factors for YaRN.""" + # Calculate the high and low frequency cutoffs for YaRN. Note: `beta_fast` and `beta_slow` are called + # `high_freq_factor` and `low_freq_factor` in the Llama 3.1 RoPE scaling code. + high_freq_cutoff = 2 * math.pi * self.beta_fast / original_seq_len + low_freq_cutoff = 2 * math.pi * self.beta_slow / original_seq_len + # Obtain a smooth mask that has a value of 0 for low frequencies and 1 for high frequencies, with linear + # interpolation in between. + smooth_mask = torch.clamp((inv_freq - low_freq_cutoff) / (high_freq_cutoff - low_freq_cutoff), min=0, max=1) + # For low frequencies, we scale the frequency by 1/self.scale. For high frequencies, we keep the frequency. 
+ scale_factors = (1 - smooth_mask) / self.scale + smooth_mask + return scale_factors + + def get_freqs_non_repeated(self, max_seq_len_cached: int, offset: int = 0) -> Tensor: + dtype = self.spatial_inv_freq.dtype + device = self.spatial_inv_freq.device + + self.seq = (torch.arange(max_seq_len_cached, device=device, dtype=dtype) + offset).cuda() + + assert hasattr( + self, "latent_shape" + ), "Latent shape is not set. Please run set_latent_shape() method on rope embedding. " + T, H, W = self.latent_shape + half_emb_t = torch.outer(self.seq[:T], self.temporal_inv_freq.cuda()) + half_emb_h = torch.outer(self.seq[:H], self.spatial_inv_freq.cuda()) + half_emb_w = torch.outer(self.seq[:W], self.spatial_inv_freq.cuda()) + emb = torch.cat( + [ + repeat(half_emb_t, "t d -> t h w d", h=H, w=W), + repeat(half_emb_h, "h d -> t h w d", t=T, w=W), + repeat(half_emb_w, "w d -> t h w d", t=T, h=H), + ] + * 2, + dim=-1, + ) + emb = rearrange(emb, "t h w d -> (t h w) 1 1 d").float() + return emb + + @lru_cache(maxsize=32) + def forward(self, seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor: + if self.spatial_inv_freq.device.type == "cpu": + # move `inv_freq` to GPU once at the first micro-batch forward pass + self.spatial_inv_freq = self.spatial_inv_freq.to(device=torch.cuda.current_device()) + + max_seq_len_cached = self.max_seq_len_cached + if self.apply_yarn and seq_len > max_seq_len_cached: + max_seq_len_cached = seq_len + self.max_seq_len_cached = max_seq_len_cached + emb = self.get_freqs_non_repeated(self.max_seq_len_cached) + return emb + + +if TYPE_CHECKING: + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + + +@dataclass +class CosmosConfig(Llama3Config): + qk_layernorm: bool = True + rope_dim: str = "3D" + vocab_size: int = 64000 + activation_func = F.silu + + def configure_model(self, tokenizer) -> "MCoreGPTModel": + model = super().configure_model(tokenizer) + if self.rope_dim == "3D": + model.rotary_pos_emb = RotaryEmbedding3D( + seq_len=self.seq_length, + training_type=None, + kv_channels=self.kv_channels, + max_position_embeddings=self.seq_length, + original_max_position_embeddings=self.original_seq_len if hasattr(self, "original_seq_len") else None, + rotary_base=self.rotary_base, + apply_yarn=True if hasattr(self, "apply_yarn") else False, + scale=self.yarn_scale if hasattr(self, "yarn_scale") else None, + extrapolation_factor=1, + attn_factor=1, + beta_fast=self.yarn_beta_fast if hasattr(self, "yarn_beta_fast") else 32, + beta_slow=self.yarn_beta_slow if hasattr(self, "yarn_beta_slow") else 1, + latent_shape=[5, 40, 64], + original_latent_shape=self.original_latent_shape if hasattr(self, "original_latent_shape") else None, + ) + return model + + +@dataclass +class CosmosConfig4B(CosmosConfig): + rotary_base: int = 500_000 + seq_length: int = 15360 + num_layers: int = 16 + hidden_size: int = 4096 + ffn_hidden_size: int = 14336 + num_attention_heads: int = 32 + num_query_groups: int = 8 + layernorm_epsilon: float = 1e-5 + use_cpu_initialization: bool = True + make_vocab_size_divisible_by: int = 128 + kv_channels: int = 128 + + +@dataclass +class CosmosConfig12B(CosmosConfig): + rotary_base: int = 500_000 + seq_length: int = 15360 + num_layers: int = 40 + hidden_size: int = 5120 + ffn_hidden_size: int = 14336 + num_attention_heads: int = 32 + num_query_groups: int = 8 + layernorm_epsilon: float = 1e-5 + use_cpu_initialization: bool = True + make_vocab_size_divisible_by: int = 128 + kv_channels: int = 128 + original_latent_shape = [3, 40, 64] + 
apply_yarn: bool = True + yarn_beta_fast: int = 4 + yarn_beta_slow: int = 1 + yarn_scale: int = 2 + original_seq_len = 8192 + + +class CosmosModel(LlamaModel): + def __init__( + self, + config: Annotated[Optional[CosmosConfig], Config[CosmosConfig]] = None, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, + ): + super().__init__(config or CosmosConfig4B(), optim=optim, tokenizer=tokenizer, model_transform=model_transform) + self.config = config + + +@io.state_transform( + source_key=( + "model.layers.*.feed_forward.w1.weight", + "model.layers.*.feed_forward.w3.weight", + ), + target_key="decoder.layers.*.mlp.linear_fc1.weight", +) +def _mlp_glu(ctx: io.TransformCTX, w1, w3): + return torch.cat((w1, w3), axis=0) + + +@io.state_transform( + source_key=( + "model.layers.*.attention.wq.weight", + "model.layers.*.attention.wk.weight", + "model.layers.*.attention.wv.weight", + ), + target_key="decoder.layers.*.self_attention.linear_qkv.weight", +) +def _import_qkv_cosmos(ctx: io.TransformCTX, q, k, v): + megatron_config = ctx.target.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_size = megatron_config.kv_channels + + old_tensor_shape = q.size() + new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] + new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] + + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights_l = [] + for i in range(num_query_groups): + qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) + qkv_weights_l.append(k[i : i + 1, :, :]) + qkv_weights_l.append(v[i : i + 1, :, :]) + qkv_weights = torch.cat(qkv_weights_l) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + return qkv_weights + + +@io.model_importer(CosmosModel, "pt") +class PTCosmosImporter(io.ModelConnector["PTCosmosModel", CosmosModel]): + def init(self) -> CosmosModel: + return CosmosModel(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + pt_model_path = str(self) + cosmos_model_state_dict = torch.load(pt_model_path, map_location="cpu") + for k, v in cosmos_model_state_dict.items(): + # convert to float 32 (for cpu conversion) (Original model is bf16) + cosmos_model_state_dict[k] = v.float() + + # Small wrapper since nemo calls source.state_dict() , to get state dict + class WrapperCosmos: + def __init__(self, model_state_dict): + self.model_state_dict = model_state_dict + + def state_dict(self): + return self.model_state_dict + + source = WrapperCosmos(cosmos_model_state_dict) + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + log.info(f"Converted PT Cosmos model to Nemo, model saved to {output_path}") + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source, target): + mapping = { + "model.tok_embeddings.weight": 
"embedding.word_embeddings.weight", + "model.layers.*.attention.wo.weight": "decoder.layers.*.self_attention.linear_proj.weight", + "model.layers.*.attention.q_norm.weight": "decoder.layers.*.self_attention.q_layernorm.weight", + "model.layers.*.attention.k_norm.weight": "decoder.layers.*.self_attention.k_layernorm.weight", + "model.layers.*.attention_norm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "model.layers.*.feed_forward.w2.weight": "decoder.layers.*.mlp.linear_fc2.weight", + "model.layers.*.ffn_norm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "model.norm.weight": "decoder.final_layernorm.weight", + "model.output.weight": "output_layer.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv_cosmos, _mlp_glu]) + + @property + def tokenizer(self): + return None + + @property + def config(self): + if "4B" in str(self) or "4b" in str(self): + return CosmosConfig4B() + elif "12B" in str(self) or "12b" in str(self): + return CosmosConfig12B() + else: + raise ValueError("Unable to infer model size from checkpoint") diff --git a/cosmos1/models/autoregressive/nemo/download_autoregressive_nemo.py b/cosmos1/models/autoregressive/nemo/download_autoregressive_nemo.py new file mode 100644 index 0000000000000000000000000000000000000000..5a3327f8d53d1547ae60d1d28d08783f28bc05b0 --- /dev/null +++ b/cosmos1/models/autoregressive/nemo/download_autoregressive_nemo.py @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from huggingface_hub import snapshot_download + + +def download_autregressive_nemo(): + """ + Downloads all Cosmos Autoregressive NeMo assets to HF_HOME directory. + Make sure to set HF_HOME to your desired path before running this function. + """ + snapshot_download("nvidia/Cosmos-1.0-Guardrail") + snapshot_download("nvidia/Cosmos-1.0-Tokenizer-DV8x16x16") + snapshot_download("nvidia/Cosmos-1.0-Autoregressive-4B", allow_patterns=["nemo/*"]) + snapshot_download("nvidia/Cosmos-1.0-Autoregressive-12B", allow_patterns=["nemo/*"]) + snapshot_download("nvidia/Cosmos-1.0-Diffusion-7B-Decoder-DV8x16x16ToCV8x8x8") + + +def main(): + # Check if HF_HOME is set + hf_home = os.environ.get("HF_HOME") + if not hf_home: + raise EnvironmentError( + "The HF_HOME environment variable is not set. " + "Please set it to your desired path before running this script." 
+ ) + + # Download Cosmos Autoregressive NeMo checkpoints + download_autregressive_nemo() + + +if __name__ == "__main__": + main() diff --git a/cosmos1/models/autoregressive/nemo/inference/README.md b/cosmos1/models/autoregressive/nemo/inference/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c1227a2bfc8d65fa68ec051ddbd7af15e9c0827b --- /dev/null +++ b/cosmos1/models/autoregressive/nemo/inference/README.md @@ -0,0 +1,167 @@ +# Cosmos Autoregressive-based World Foundation Models: NeMo Framework User Guide + +Learn how to [run inference](#run-inference) with Cosmos Autoregressive-based World Foundation Models (WFMs) using the [NVIDIA NeMo Framework](https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html) for your custom Physical AI tasks by following this guide. + +## Model Support Matrix + +The NeMo Framework supports the following Cosmos Autoregressive (AR) models. Review the available models and their compute requirements for post-training and inference to determine the best model for your use case. + +| Model Name | Model Status | Compute Requirements for Inference | Multi-GPU Support | +|----------------------------------------------|------------------|------------------------------------------|---------| +| Cosmos-1.0-Autoregressive-4B | **Supported** | 1 NVIDIA GPU* | **Coming Soon** | +| Cosmos-1.0-Autoregressive-12B | **Supported** | 1 NVIDIA GPU* | **Coming Soon** | +| Cosmos-1.0-Autoregressive-5B-Video2World | **Coming Soon** | | | +| Cosmos-1.0-Autoregressive-13B-Video2World | **Coming Soon** | | | + +**\*** `H100-80GB` or `A100-80GB` GPUs are recommended. + +## Post-Training Inference Support Matrix + +Cosmos Autoregressive-based WFMs can be post-trained for a variety of Physical AI tasks. Review the following table for a list of available Physical AI post-training tasks: + +| Post-training Task | Inference Support Status | +|-------------------------|--------------------| +| General post-training | **Supported** | +| Instruction control | **Coming Soon** | +| Action control | **Coming Soon** | +| Camera control | **Coming Soon** | +| Multi-view generation | **Coming Soon** | +| Multi-view generation with vehicle trajectory control | **Coming Soon** | + +## Prerequisites + +### 1. Review General Requirements + +- System Configuration + - **NVIDIA GPU and driver**: Ensure you have access to the minimum compute required to run the model(s), as listed in the model support matrix. + - **Containerization Platform**: We recommend using Docker with NVIDIA Container Runtime (alternatively, you may use NVIDIA enroot). +- Get your [Hugging Face User Access Token](https://huggingface.co/docs/hub/en/security-tokens), which is required to obtain the Cosmos models for training and inference. +- Get your [Weights and Biases API Key](https://docs.wandb.ai/support/find_api_key/) for logging and tracking. + +### 2. Clone the Cosmos Repository + +```bash +git clone git@github.com:NVIDIA/Cosmos.git +``` + +### 3. Start the Container + +The [NeMo Framework container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) supports post-training and inference for Cosmos AR models. + +Run the following command to download and start the container: + ```bash + docker run --ipc=host -it --gpus=all \ + -v $PATH_TO_COSMOS_REPO:/workspace/Cosmos \ + nvcr.io/nvidia/nemo:cosmos.1.0 bash + ``` + +### 4. 
Download Checkpoints + +To help you get started, we've provided a [download script](../download_autoregressive_nemo.py) to get the Cosmos Autoregressive checkpoints from Hugging Face. These checkpoints are in the NeMo distributed checkpoint format required to run post-training and inference with NeMo Framework. + +1. Set the following environment variables: + ```bash + # You must set HF_HOME before running this script. + export HF_TOKEN="" + export HF_HOME="" + ``` +2. Run the following command to download the models: + ```bash + cd /workspace/Cosmos + python cosmos1/models/autoregressive/nemo/download_autoregressive_nemo.py + ``` + +## Run Inference + +Running inference with Cosmos AR models lets you predict video frames and generate a new video that continues the scene from a given input video. + +In this guide, we'll use this [example inference script](./general.py) to tokenize the input video into a sequence of tokens, which serve as prompts for the model. The model then generates new tokens representing the next set of frames. Finally, the new tokens are decoded back into video format. Only the last 9 frames of the input video are used to generate the next 24 frames. + +### Run the Inference Script with Base Model + +Complete the following steps to run inference on the 4B model. + +1. Set the following environment variables: + ```bash + export HF_TOKEN="" + export HF_HOME="" + + # Path to the the mp4 file (In git-lfs) + export INPUT_DATA=cosmos1/models/autoregressive/assets/v1p0/input.mp4 + ``` +2. Run the following command: + ```bash + cd /workspace/Cosmos + git lfs pull $INPUT_DATA + + NVTE_FLASH_ATTN=1 \ + NVTE_FUSED_ATTN=0 \ + NVTE_UNFUSED_ATTN=0 \ + torchrun --nproc-per-node 1 cosmos1/models/autoregressive/nemo/inference/general.py \ + --input_image_or_video_path $INPUT_DATA \ + --video_save_name "Cosmos-1.0-Autoregressive-4B.mp4" \ + --ar_model_dir nvidia/Cosmos-1.0-Autoregressive-4B + ``` + +### Run the Inference Script with Post-trained Model + +Complete the following steps to generate a new output video. + +1. Set the following environment variables: + ```bash + export HF_TOKEN="" + export HF_HOME="" + + # Inference with post-trained model. + export NEMO_CHECKPOINT=./logs/default/checkpoints/epoch\=0-step\=19 + + # Path to the the mp4 file (In git-lfs) + export INPUT_DATA=cosmos1/models/autoregressive/assets/v1p0/input.mp4 + + ``` +2. Run the following command: + ```bash + cd /workspace/Cosmos + git lfs pull $INPUT_DATA + + # change --ar_model_dir to a post-trained checkpoint under ./logs/default/checkpoints/ + NVTE_FLASH_ATTN=1 \ + NVTE_FUSED_ATTN=0 \ + NVTE_UNFUSED_ATTN=0 \ + torchrun --nproc-per-node 1 cosmos1/models/autoregressive/nemo/inference/general.py \ + --input_image_or_video_path $INPUT_DATA \ + --video_save_name "Cosmos-1.0-Autoregressive-4B.mp4" \ + --ar_model_dir "$NEMO_CHECKPOINT" + ``` + +#### Example Output + +The following output is an example video generated from the post-trained model using [`general.py`](./general.py): + + + +Generated videos are saved at the location configured in the `--video_save_name` parameter. + +The input video used to generate this video can be found in `cosmos1/models/autoregressive/assets/v1p0/input.mp4`. + +> **Disclaimer**: +> The post-training example in this documentation is a demonstration of general post-training and not a guaranteed recipe for success. Post-training outcomes depend heavily on the quality and diversity of the dataset. 
To achieve good results, ensure your dataset is clean, well-structured, diverse, and properly labeled. Poorly prepared data can lead to issues like overfitting, bias, or poor performance. Carefully curate your dataset to reflect the desired use case for reliable results. + +### Configuration Options + +The following table details the parameters that can be modified for accelerated inference with NeMo. You can adjust these parameters to optimize performance based on your specific requirements. + +| Parameter | Description | Default | +|--------------------------------|---------------------------------------------------------------------------------|---------| +| `--input_type` | The input type (image or video) | `video` | +| `--input_image_or_video_path` | Path to the input video to run inference | `cosmos1/models/autoregressive/assets/v1p0/input.mp4` | +| `--video_save_name` | Path to generated video | `./nemo_generated_video.mp4` | +| `--ar_model_dir` | Model name or path to model `nvidia/Cosmos-1.0-Autoregressive-4B` or `nvidia/Cosmos-1.0-Autoregressive-12B` | `nvidia/Cosmos-1.0-Autoregressive-4B` | +| `--encoder_path` | Path to encoder | `nvidia/Cosmos-1.0-Tokenizer-DV8x16x16` | +| `--decoder_path` | Path to the decoder | `nvidia/Cosmos-1.0-Tokenizer-DV8x16x16` | +| `--guardrail_dir` | Path to guardrails | `nvidia/Cosmos-1.0-Guardrail` | +| `--top_p` | Top-p inference parameter | `0.8` | +| `--temperature` | Sampling temperature | `1` | +| `--disable_diffusion_decoder` | Disables running diffusion decoder on the generated result | `False` | diff --git a/cosmos1/models/autoregressive/nemo/inference/general.py b/cosmos1/models/autoregressive/nemo/inference/general.py new file mode 100644 index 0000000000000000000000000000000000000000..500f352d231d147f82717e5ae0edc6de49f86186 --- /dev/null +++ b/cosmos1/models/autoregressive/nemo/inference/general.py @@ -0,0 +1,239 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
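# Editor's note (illustrative sketch, not part of the original script): the constants defined
# below imply the frame/token budget described in the README above. The DV8x16x16 tokenizer
# maps NUM_CONTEXT_FRAMES = 33 frames at 640x1024 to a latent grid of T=5, H=40, W=64; for
# video input, 2 latent frames are kept as context, so the model generates
#     (5 - 2) * 40 * 64 = 7680 tokens,
# which decode back into the 24 new frames that follow the 9 input frames (9 + 24 = 33).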
+ +import gc +import os +from argparse import ArgumentParser +from typing import List + +import imageio +import nemo.lightning as nl +import numpy as np +import torch +from einops import rearrange +from huggingface_hub import snapshot_download +from megatron.core.inference.common_inference_params import CommonInferenceParams +from megatron.core.inference.engines.mcore_engine import MCoreEngine +from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import ( + SimpleTextGenerationController, +) +from nemo.collections.llm.inference.base import _setup_trainer_and_restore_model +from nemo.lightning import io +from nemo.lightning.ckpt_utils import ckpt_to_context_subdir + +from cosmos1.models.autoregressive.nemo.utils import run_diffusion_decoder_model +from discrete_video import DiscreteVideoFSQJITTokenizer +from cosmos1.models.autoregressive.utils.inference import load_vision_input +from .presets import presets as guardrail_presets +from .log import log + +torch._C._jit_set_texpr_fuser_enabled(False) + +TOKENIZER_COMPRESSION_FACTOR = [8, 16, 16] +NUM_CONTEXT_FRAMES = 33 +NUM_INPUT_FRAMES_VIDEO = 9 +LATENT_SHAPE = [5, 40, 64] +DATA_RESOLUTION = [640, 1024] + + +class CosmosMCoreTokenizerWrappper: + """ + A small dummy wrapper to pass into the text generation controller. + """ + + def __init__(self): + self.tokenizer = None + self.eod = -1 + self.vocab_size = 64000 + + def detokenize(self, tokens: List[int], remove_special_tokens: bool = False): + return tokens + + def tokenize(self, prompt: List[int]): + return prompt + + +def main(args): + num_input_frames = 1 if args.input_type == "image" else NUM_INPUT_FRAMES_VIDEO + + vision_input_dict = load_vision_input( + input_type=args.input_type, + batch_input_path=None, + input_image_or_video_path=args.input_image_or_video_path, + data_resolution=DATA_RESOLUTION, + num_input_frames=num_input_frames, + ) + + vision_input = list(vision_input_dict.values())[0].cuda() + + T, H, W = LATENT_SHAPE + latent_context_t_size = 1 if args.input_type == "image" else 2 + num_tokens_to_generate = int(np.prod([T - latent_context_t_size, H, W])) + + # Encode and Tokenize + if args.encoder_path == "nvidia/Cosmos-1.0-Tokenizer-DV8x16x16": + args.encoder_path = os.path.join(snapshot_download(args.encoder_path), "encoder.jit") + if args.decoder_path == "nvidia/Cosmos-1.0-Tokenizer-DV8x16x16": + args.decoder_path = os.path.join(snapshot_download(args.decoder_path), "decoder.jit") + video_tokenizer = DiscreteVideoFSQJITTokenizer( + enc_fp=args.encoder_path, + dec_fp=args.decoder_path, + name="discrete_video_fsq", + pixel_chunk_duration=NUM_CONTEXT_FRAMES, + latent_chunk_duration=T, + ).cuda() + + quantized_out, _ = video_tokenizer.encode(vision_input, pixel_chunk_duration=None) + indices = video_tokenizer.fsq_quantizer.codes_to_indices(quantized_out.permute(0, 2, 3, 4, 1)) + indices = rearrange(indices, "B T H W -> B (T H W)") + video_tokens = [indices[0][0:-num_tokens_to_generate].tolist()] + + # Load the nemo model + if args.ar_model_dir in ["nvidia/Cosmos-1.0-Autoregressive-4B", "nvidia/Cosmos-1.0-Autoregressive-12B"]: + args.ar_model_dir = os.path.join(snapshot_download(args.ar_model_dir, allow_patterns=["nemo/*"]), "nemo") + model: io.TrainerContext = io.load_context(path=ckpt_to_context_subdir(args.ar_model_dir), subpath="model") + + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=1, + sequence_parallel=False, + setup_optimizers=False, + 
store_optimizer_states=False, + ) + + trainer = nl.Trainer( + accelerator="gpu", + devices=1, + num_nodes=1, + strategy=strategy, + num_sanity_val_steps=0, + plugins=nl.MegatronMixedPrecision( + precision="bf16-mixed", + params_dtype=torch.bfloat16, + pipeline_dtype=torch.bfloat16, + autocast_enabled=False, + grad_reduce_in_fp32=False, + ), + ) + _setup_trainer_and_restore_model(path=args.ar_model_dir, trainer=trainer, model=model) + + inference_wrapped_model = model.get_inference_wrapper(torch.bfloat16, inference_batch_times_seqlen_threshold=1000) + + # Generate tokens + text_generation_controller = SimpleTextGenerationController( + inference_wrapped_model=inference_wrapped_model, tokenizer=CosmosMCoreTokenizerWrappper() + ) + + mcore_engine = MCoreEngine(text_generation_controller=text_generation_controller, max_batch_size=1) + + common_inference_params = CommonInferenceParams( + temperature=args.temperature, top_p=args.top_p, num_tokens_to_generate=num_tokens_to_generate + ) + + log.info(f"Running Inference to generate {num_tokens_to_generate} tokens. This will take some time. ") + results = mcore_engine.generate( + prompts=video_tokens, + add_BOS=False, + encoder_prompts=None, + common_inference_params=common_inference_params, + ) + + result = list(results)[0] + prompt_tokens = torch.tensor(result.prompt_tokens).cuda() + prompt_tokens[prompt_tokens == -1] = result.generated_tokens + + indices_tensor = prompt_tokens.unsqueeze(dim=0) + indices_tensor = rearrange( + indices_tensor, + "B (T H W) -> B T H W", + T=LATENT_SHAPE[0], + H=LATENT_SHAPE[1], + W=LATENT_SHAPE[2], + ) + + if torch.cuda.current_device() == 0: + # Decode the generated tokens + log.info("Running diffusion model on the generated result") + video_decoded = video_tokenizer.decode(indices_tensor.cuda()) + out_video = (video_decoded * 0.5 + 0.5).clamp_(0, 1) + + if not args.disable_diffusion_decoder: + del model + del inference_wrapped_model + del video_tokenizer + model = None + inference_wrapped_model = None + video_tokenizer = None + gc.collect() + torch.cuda.empty_cache() + + out_video = run_diffusion_decoder_model( + indices_tensor_cur_batch=[indices_tensor.squeeze()], out_videos_cur_batch=out_video + ) + + out_video = out_video[0].detach().clone() + output_video = (out_video * 255).to(torch.uint8).permute(1, 2, 3, 0).cpu().numpy() + + if args.guardrail_dir: + log.info("Running guardrails on the generated video") + if args.guardrail_dir == "nvidia/Cosmos-1.0-Guardrail": + args.guardrail_dir = snapshot_download(args.guardrail_dir) + video_guardrail = guardrail_presets.create_video_guardrail_runner(checkpoint_dir=args.guardrail_dir) + output_video = guardrail_presets.run_video_guardrail(output_video, video_guardrail) + if output_video is None: + raise ValueError("Guardrail blocked world generation.") + + # Write the video to disk + imageio.mimsave( + args.video_save_name, + output_video, + fps=25, # We use a fps of 25 just for visualization. 
+ ) + + log.info(f"Saved to {args.video_save_name}") + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--input_type", type=str, default="video", help="Type of input", choices=["image", "video"]) + parser.add_argument( + "--input_image_or_video_path", required=True, type=str, help="The path to the input video to run inference" + ) + parser.add_argument( + "--video_save_name", default="./nemo_generated_video.mp4", type=str, help="The path to generated video" + ) + parser.add_argument( + "--ar_model_dir", + default="nvidia/Cosmos-1.0-Autoregressive-4B", + type=str, + help="The path to the nemo autoregressive model", + ) + parser.add_argument( + "--encoder_path", default="nvidia/Cosmos-1.0-Tokenizer-DV8x16x16", type=str, help="The path to encoder" + ) + parser.add_argument( + "--decoder_path", default="nvidia/Cosmos-1.0-Tokenizer-DV8x16x16", type=str, help="The path to the decoder" + ) + parser.add_argument( + "--guardrail_dir", default="nvidia/Cosmos-1.0-Guardrail", type=str, help="The path to the guardrails" + ) + parser.add_argument("--top_p", default=0.8, type=float, help="The top_p inference parameter ") + parser.add_argument("--temperature", default=1, type=int, help="Sampling temperature") + parser.add_argument("--disable_diffusion_decoder", action="store_true", help="Disable diffusion decoder") + + args = parser.parse_args() + + main(args) diff --git a/cosmos1/models/autoregressive/nemo/post_training/README.md b/cosmos1/models/autoregressive/nemo/post_training/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ef0f7e18b530971bf892f4808abc9e1cbe2cb615 --- /dev/null +++ b/cosmos1/models/autoregressive/nemo/post_training/README.md @@ -0,0 +1,183 @@ +# Cosmos Autoregressive-based World Foundation Models: NeMo Framework User Guide + +Learn how to [post-train](#post-train) Cosmos Autoregressive-based World Foundation Models (WFMs) using the [NVIDIA NeMo Framework](https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html) for your custom Physical AI tasks by following this guide. + +## Model Support Matrix + +The NeMo Framework supports the following Cosmos Autoregressive (AR) models. Review the available models and their compute requirements for post-training and inference to determine the best model for your use case. + +| Model Name | Model Status | Compute Requirements for Post-Training | +|-------------------------|----------------------------|-------------------------------------------| +| Cosmos-1.0-Autoregressive-4B | **Supported** | 2 NVIDIA GPUs* | +| Cosmos-1.0-Autoregressive-12B | **Supported** | 8 NVIDIA GPUs* | +| Cosmos-1.0-Autoregressive-5B-Video2World | **Coming Soon** | | +| Cosmos-1.0-Autoregressive-13B-Video2World | **Coming Soon** | | + +**\*** `H100-80GB` or `A100-80GB` GPUs are recommended. + +## Post-Training Support Matrix + +Cosmos Autoregressive-based WFMs can be post-trained for a variety of Physical AI tasks. Review the following table for a list of available Physical AI post-training tasks: + +| Post-training Task | Support Status | +|-------------------------|--------------------| +| General post-training | **Supported** | +| Instruction control | **Coming Soon** | +| Action control | **Coming Soon** | +| Camera control | **Coming Soon** | +| Multi-view generation | **Coming Soon** | +| Multi-view generation with vehicle trajectory control | **Coming Soon** | + +## Prerequisites + +### 1. 
Review General Requirements + +- System Configuration + - **NVIDIA GPU and driver**: Ensure you have access to the minimum compute required to run the model(s), as listed in the model support matrix. + - **Containerization Platform**: We recommend using Docker with NVIDIA Container Runtime (alternatively, you may use NVIDIA enroot). +- Get your [Hugging Face User Access Token](https://huggingface.co/docs/hub/en/security-tokens), which is required to obtain the Cosmos models for training and inference. +- Get your [Weights and Biases API Key](https://docs.wandb.ai/support/find_api_key/) for logging and tracking. + +### 2. Clone the Cosmos Repository + +```bash +git clone git@github.com:NVIDIA/Cosmos.git +``` + +### 3. Start the Container + +The [NeMo Framework container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) supports post-training and inference for Cosmos AR models. + +Run the following command to download and start the container: + ```bash + docker run --ipc=host -it --gpus=all \ + -v $PATH_TO_COSMOS_REPO:/workspace/Cosmos \ + nvcr.io/nvidia/nemo:cosmos.1.0 bash + ``` + +### 4. Download Checkpoints + +To help you get started, we've provided a [download script](../download_autoregressive_nemo.py) to get the Cosmos Autoregressive checkpoints from Hugging Face. These checkpoints are in the NeMo distributed checkpoint format required to run post-training and inference with NeMo Framework. + +1. Set the following environment variables: + ```bash + # You must set HF_HOME before running this script. + export HF_TOKEN="" + export HF_HOME="" + ``` +2. Run the following command to download the models: + ```bash + cd /workspace/Cosmos + python cosmos1/models/autoregressive/nemo/download_autoregressive_nemo.py + ``` + +## Post-train + +Post-training a Cosmos Autoregressive-based WFM enables you to train the model to generate videos using frame predictions that are more specific to your Physical AI use case. + +For example, if you want to generate action sequences for a specific robot, you can post-train the model to generate videos that are more aligned with typical actions/outcomes for that robot. + +There are 3 steps to post-training: preparing a dataset, preprocessing the data, and post-training the model. + +### 1. Prepare a Dataset + +The first step is to prepare a dataset. Post-training a Cosmos-1.0-Autoregressive-4B model enables you to get better video-frame predictions for your specific use case. + +You must provide a folder containing a collection of videos in **MP4 format**, preferably 720p. In this guide, we'll use the sample videos located in the `cosmos1/models/autoregressive/assets/v1p0/batch_inputs` directory. + +### 2. Preprocess Data + +The second step is to preprocess the data to create an [indexed dataset](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/datasets). The `IndexedDataset` class is the lowest-level data interface in Megatron Core and creates a `.bin` and `.idx` file. + +Before proceeding, ensure all videos are in **RGB format**. Complete the following steps to preprocess the data. + +1. Set the following environment variables: + ```bash + export HF_TOKEN="" + export HF_HOME="" + + # Path to Raw mp4 videos. + export RAW_DATA="cosmos1/models/autoregressive/assets/v1p0/batch_inputs" + + # Path to Processed Dataset. + export OUTPUT_PREFIX="./indexed_videos" + + ``` +2. 
Run the following command to preprocess the data:
+   ```bash
+   cd /workspace/Cosmos
+   git lfs pull --include=$RAW_DATA
+
+   python cosmos1/models/autoregressive/nemo/post_training/prepare_dataset.py \
+   --input_videos_dir $RAW_DATA \
+   --output_prefix $OUTPUT_PREFIX
+   ```
+
+Executing the [data preprocessing script](./prepare_dataset.py) generates the following files for each video:
+
+- **`[i].idx` File**: This file contains metadata at the dataset level:
+  - **Index Header**: Ensures backward compatibility.
+  - **Index Version**: Maintains backward compatibility.
+  - **Data Type Code**: Numeric code indicating the data type used in the data file.
+  - **Sequence Count**: Total number of sequences in the dataset.
+  - **Document Count**: Total number of documents in the dataset.
+
+- **`[i].bin` File**: This file includes metadata at the document and sequence levels:
+  - **Elements per Sequence**: Number of elements in each sequence.
+  - **Byte Offset per Sequence**: Pointer indicating the start of each sequence.
+  - **Sequence Index Range**: Consecutive index range `[...)` for each document.
+
+### 3. Post-train the Model
+
+The third step is to post-train the model. This step uses NeMo Framework's data and model parallelism capabilities to train the model on the post-training samples. This is accomplished by utilizing Tensor Parallelism.
+
+- **Tensor Parallelism**: Spreads the parameter tensor of individual layers across GPUs.
+
+#### Run the Post-training Script
+
+Complete the following steps to post-train the Cosmos-1.0-Autoregressive-4B model.
+
+1. Set the following environment variables:
+   ```bash
+   export HF_TOKEN=""
+   export HF_HOME=""
+
+   # Number of GPU devices available for post-training. At least 2 for 4B and 8 for 12B.
+   export NUM_DEVICES=2
+
+   # Optionally, you can monitor training progress with Weights and Biases (wandb).
+   export WANDB_API_KEY=""
+   export WANDB_PROJECT_NAME="cosmos-autoregressive-nemo-finetuning"
+   export WANDB_RUN_ID="cosmos_autoregressive_4b_finetune"
+   ```
+2. Run the following command for Cosmos-1.0-Autoregressive-4B post-training:
+   ```bash
+   torchrun --nproc-per-node $NUM_DEVICES cosmos1/models/autoregressive/nemo/post_training/general.py \
+   --data_path $OUTPUT_PREFIX \
+   --split_string 4,1,1 \
+   --log_dir ./logs \
+   --max_steps 20 --save_every_n_steps 10 \
+   --tensor_model_parallel_size $NUM_DEVICES \
+   --model_path nvidia/Cosmos-1.0-Autoregressive-4B
+   ```
+
+3. You can now run inference with your post-trained model using the instructions [here](../inference/README.md#run-the-inference-script-with-post-trained-model).
+
+#### Configuration Options
+
+Before getting started, review the following parameters made available to the script. You can adjust these parameters to optimize performance based on your specific requirements.
+
+| Parameter | Description | Default |
+|---|---|---|
+| `--data_path` | Specifies the location of your preprocessed dataset. Ensure this path points to the directory containing your `.bin` and `.idx` files. | `/path/to/data` |
+| `--model_path` | Specifies the Cosmos model to post-train: either a local NeMo checkpoint directory or a Hugging Face model ID. | `nvidia/Cosmos-1.0-Autoregressive-4B` |
+| `--index_mapping_dir` | Specifies the directory where the dataset index mapping files are stored. | `./index_mapping` |
+| `--log_dir` | Specifies the directory to store the logs and checkpoints. | `./log_dir` |
+| `--split_string` | Specifies the data split ratios for training, validation, and testing. | `4,1,1` |
+| `--tensor_model_parallel_size` | Controls the number of GPUs used for model parallelism. Increase this number to scale up, ensuring your hardware can support the additional load. | `2` |
+| `--max_steps` | Defines the total number of training steps. Adjust based on dataset size and desired training duration. | `100` |
+| `--save_every_n_steps` | Defines how often checkpoints are saved. Adjust based on training duration and storage capacity. | `10` |
+| `--global_batch_size` | Adjust to balance memory usage and training speed. Larger batch sizes may improve convergence but require more memory. | `2` |
+| `--micro_batch_size` | Adjust to balance memory usage and training speed. Larger micro batches require more memory per GPU. | `1` |
+| `--lr` | Sets the learning rate. A common starting point is `5e-5`, but this can be adjusted based on model performance and convergence behavior. | `5e-5` |
+| `--max_epochs` | The maximum number of epochs to run during post-training. | `10` |
diff --git a/cosmos1/models/autoregressive/nemo/post_training/general.py b/cosmos1/models/autoregressive/nemo/post_training/general.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7f56b924d713f07770b10f5c3f3db7ab8049183
--- /dev/null
+++ b/cosmos1/models/autoregressive/nemo/post_training/general.py
@@ -0,0 +1,129 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
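+
+# This script post-trains a Cosmos Autoregressive world foundation model with the NeMo
+# Framework, using the indexed (.bin/.idx) video-token dataset produced by
+# prepare_dataset.py. See the accompanying NeMo post-training README for the torchrun
+# launch command and a description of each argument.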
+ +import os +from argparse import ArgumentParser + +import torch +from huggingface_hub import snapshot_download +from lightning.pytorch.loggers import WandbLogger +from megatron.core.optimizer import OptimizerConfig +from nemo import lightning as nl +from nemo.collections import llm +from nemo.lightning.pytorch.callbacks import ModelCheckpoint, PreemptionCallback +from nemo.lightning.pytorch.strategies.utils import RestoreConfig + +from cosmos1.models.autoregressive.nemo.cosmos import CosmosConfig4B, CosmosConfig12B, CosmosModel + + +def main(args): + if "4B" in args.model_path: + config = CosmosConfig4B() + elif "12B" in args.model_path: + config = CosmosConfig12B() + else: + raise NotImplementedError + + if args.model_path in ["nvidia/Cosmos-1.0-Autoregressive-4B", "nvidia/Cosmos-1.0-Autoregressive-12B"]: + args.model_path = os.path.join(snapshot_download(args.model_path, allow_patterns=["nemo/*"]), "nemo") + + model = CosmosModel(config) + + data_module = llm.PreTrainingDataModule( + paths=[args.data_path], + seq_length=12800, + global_batch_size=args.global_batch_size, + micro_batch_size=args.micro_batch_size, + tokenizer=None, + split=args.split_string, + num_workers=1, + index_mapping_dir=args.index_mapping_dir, + ) + + # Finetune is the same as train (Except train gives the option to set tokenizer to None) + # So we use it since in this case we dont store a tokenizer with the model + llm.api.train( + model=model, + data=data_module, + trainer=nl.Trainer( + devices=args.tensor_model_parallel_size, + num_nodes=1, + max_steps=args.max_steps, + accelerator="gpu", + strategy=nl.MegatronStrategy( + tensor_model_parallel_size=args.tensor_model_parallel_size, + pipeline_model_parallel_size=1, + context_parallel_size=1, + sequence_parallel=False, + pipeline_dtype=torch.bfloat16, + ), + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + num_sanity_val_steps=0, + limit_val_batches=0, + max_epochs=args.max_epochs, + log_every_n_steps=1, + callbacks=[ + ModelCheckpoint( + monitor="reduced_train_loss", + filename="{epoch}-{step}", + every_n_train_steps=args.save_every_n_steps, + save_top_k=2, + ), + PreemptionCallback(), + ], + ), + log=nl.NeMoLogger(wandb=(WandbLogger() if "WANDB_API_KEY" in os.environ else None), log_dir=args.log_dir), + optim=nl.MegatronOptimizerModule( + config=OptimizerConfig( + lr=args.lr, + bf16=True, + params_dtype=torch.bfloat16, + use_distributed_optimizer=False, + ) + ), + tokenizer=None, + resume=nl.AutoResume( + restore_config=RestoreConfig(path=args.model_path), + resume_if_exists=True, + resume_ignore_no_checkpoint=False, + resume_past_end=True, + ), + ) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--data_path", required=True, type=str, help="The path to the .bin .idx files") + parser.add_argument( + "--model_path", default="nvidia/Cosmos-1.0-Autoregressive-4B", type=str, help="The path to the nemo model" + ) + parser.add_argument( + "--index_mapping_dir", default="./index_mapping", type=str, help="The directory to store mapped indices" + ) + parser.add_argument("--log_dir", default="./log_dir", type=str, help="The path to the logs") + parser.add_argument("--split_string", default="98,1,1", type=str, help="The train/test/validation split") + parser.add_argument("--tensor_model_parallel_size", default=2, type=int, help="Tensor model parallel size") + parser.add_argument("--max_steps", default=100, type=int, help="The max number of steps to run finetuning") + parser.add_argument("--save_every_n_steps", default=100, 
type=int, help="How often to save a checkpoint") + parser.add_argument("--global_batch_size", default=2, type=int, help="The global batch size") + parser.add_argument( + "--micro_batch_size", default=1, type=int, help="The micro batch size if using pipeline parallel" + ) + parser.add_argument("--lr", default=5e-5, type=float, help="The learning rate") + parser.add_argument("--max_epochs", default=10, type=int, help="Max number of epochs") + + args = parser.parse_args() + + main(args) diff --git a/cosmos1/models/autoregressive/nemo/post_training/prepare_dataset.py b/cosmos1/models/autoregressive/nemo/post_training/prepare_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..bc31c15435ec0aa9b4252d73b48199becfd33e9f --- /dev/null +++ b/cosmos1/models/autoregressive/nemo/post_training/prepare_dataset.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from argparse import ArgumentParser +from glob import glob + +import torch +from einops import rearrange +from huggingface_hub import snapshot_download +from nemo.collections.nlp.data.language_modeling.megatron import indexed_dataset + +from cosmos1.models.autoregressive.nemo.utils import read_input_videos +from discrete_video import DiscreteVideoFSQJITTokenizer +from .log import log + +TOKENIZER_COMPRESSION_FACTOR = [8, 16, 16] +DATA_RESOLUTION_SUPPORTED = [640, 1024] +NUM_CONTEXT_FRAMES = 33 + + +def main(args): + if args.encoder_path == "nvidia/Cosmos-1.0-Tokenizer-DV8x16x16": + args.encoder_path = os.path.join(snapshot_download(args.encoder_path), "encoder.jit") + if args.decoder_path == "nvidia/Cosmos-1.0-Tokenizer-DV8x16x16": + args.decoder_path = os.path.join(snapshot_download(args.decoder_path), "decoder.jit") + video_tokenizer = DiscreteVideoFSQJITTokenizer( + enc_fp=args.encoder_path, + dec_fp=args.decoder_path, + name="discrete_video_fsq", + pixel_chunk_duration=NUM_CONTEXT_FRAMES, + ).cuda() + + builders = {} + key = "text" + builders[key] = indexed_dataset.make_builder( + f"{args.output_prefix}.bin", + impl="mmap", + chunk_size=64, + pad_id=0, + retrieval_db=None, + vocab_size=64000, + stride=64, + ) + + filepaths_final = glob(f"{args.input_videos_dir}/*.mp4") + + for filepath in filepaths_final: + input_video = read_input_videos(filepath).cuda() + batch_size, channels, frames, height, width = input_video.shape + latent_shape = ( + (frames - 1) // TOKENIZER_COMPRESSION_FACTOR[0] + 1, + height // TOKENIZER_COMPRESSION_FACTOR[1], + width // TOKENIZER_COMPRESSION_FACTOR[2], + ) + T, H, W = latent_shape + video_tokenizer.latent_chunk_duration = T + quantized_out, _ = video_tokenizer.encode(input_video, pixel_chunk_duration=None) + indices = video_tokenizer.fsq_quantizer.codes_to_indices(quantized_out.permute(0, 2, 3, 4, 1)) + indices = rearrange(indices, "B T H W -> (B T H W)").detach().cpu() + 
builders[key].add_item(torch.IntTensor(indices).detach().cpu()) + builders[key].end_document() + + builders[key].finalize( + f"{args.output_prefix}.idx", + ) + + log.info(f"Stored the .bin and .idx files in {args.output_prefix}") + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--input_videos_dir", required=True, type=str, help="The path to the input videos") + parser.add_argument( + "--encoder_path", default="nvidia/Cosmos-1.0-Tokenizer-DV8x16x16", type=str, help="The path to encoder" + ) + parser.add_argument( + "--decoder_path", default="nvidia/Cosmos-1.0-Tokenizer-DV8x16x16", type=str, help="The path to the decoder" + ) + parser.add_argument( + "--output_prefix", + required=True, + type=str, + help="The directory along with the output file name to write the .idx and .bin files (e.g /path/to/output/sample)", + ) + args = parser.parse_args() + + with torch.no_grad(): + main(args) diff --git a/cosmos1/models/autoregressive/nemo/utils.py b/cosmos1/models/autoregressive/nemo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ac558aecba5f67b208c5eeb649e9b915d4c201 --- /dev/null +++ b/cosmos1/models/autoregressive/nemo/utils.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import importlib +import math +import os +from typing import List + +import torch +import torchvision +from huggingface_hub import snapshot_download + +from inference_config import DiffusionDecoderSamplingConfig +from cosmos1.models.autoregressive.diffusion_decoder.inference import diffusion_decoder_process_tokens +from cosmos1.models.autoregressive.diffusion_decoder.model import LatentDiffusionDecoderModel +from inference_utils import ( + load_network_model, + load_tokenizer_model, + skip_init_linear, +) +from .log import log +from config_helper import get_config_module, override + +TOKENIZER_COMPRESSION_FACTOR = [8, 16, 16] +DATA_RESOLUTION_SUPPORTED = [640, 1024] +NUM_CONTEXT_FRAMES = 33 + + +def resize_input(video: torch.Tensor, resolution: list[int]): + r""" + Function to perform aspect ratio preserving resizing and center cropping. + This is needed to make the video into target resolution. 
+ Args: + video (torch.Tensor): Input video tensor + resolution (list[int]): Data resolution + Returns: + Cropped video + """ + + orig_h, orig_w = video.shape[2], video.shape[3] + target_h, target_w = resolution + + scaling_ratio = max((target_w / orig_w), (target_h / orig_h)) + resizing_shape = (int(math.ceil(scaling_ratio * orig_h)), int(math.ceil(scaling_ratio * orig_w))) + video_resized = torchvision.transforms.functional.resize(video, resizing_shape) + video_cropped = torchvision.transforms.functional.center_crop(video_resized, resolution) + return video_cropped + + +def read_input_videos(input_video: str) -> torch.tensor: + """Utility to read the input video and return a torch tensor + + Args: + input_video (str): A path to .mp4 file + data_resolution (list, optional): The . Defaults to [640, 1024]. + + Returns: + A torch tensor of the video + """ + video, _, _ = torchvision.io.read_video(input_video) + video = video.float() / 255.0 + video = video * 2 - 1 + + if video.shape[0] > NUM_CONTEXT_FRAMES: + video = video[0:NUM_CONTEXT_FRAMES, :, :, :] + else: + log.info(f"Video doesn't have {NUM_CONTEXT_FRAMES} frames. Padding the video with the last frame.") + # Pad the video + nframes_in_video = video.shape[0] + video = torch.cat( + (video, video[-1, :, :, :].unsqueeze(0).repeat(NUM_CONTEXT_FRAMES - nframes_in_video, 1, 1, 1)), + dim=0, + ) + + video = video[0:NUM_CONTEXT_FRAMES, :, :, :] + video = video.permute(0, 3, 1, 2) + video = resize_input(video, DATA_RESOLUTION_SUPPORTED) + return video.transpose(0, 1).unsqueeze(0) + + +def run_diffusion_decoder_model(indices_tensor_cur_batch: List[torch.Tensor], out_videos_cur_batch): + """Run a 7b diffusion model to enhance generation output + + Args: + indices_tensor_cur_batch (List[torch.Tensor]): The index tensor(i.e) prompt + generation tokens + out_videos_cur_batch (torch.Tensor): The output decoded video of shape [bs, 3, 33, 640, 1024] + """ + diffusion_decoder_ckpt_path = snapshot_download("nvidia/Cosmos-1.0-Diffusion-7B-Decoder-DV8x16x16ToCV8x8x8") + dd_tokenizer_dir = snapshot_download("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8") + tokenizer_corruptor_dir = snapshot_download("nvidia/Cosmos-1.0-Tokenizer-DV8x16x16") + + diffusion_decoder_model = load_model_by_config( + config_job_name="DD_FT_7Bv1_003_002_tokenizer888_spatch2_discrete_cond_on_token", + config_file="cosmos1/models/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py", + model_class=LatentDiffusionDecoderModel, + encoder_path=os.path.join(tokenizer_corruptor_dir, "encoder.jit"), + decoder_path=os.path.join(tokenizer_corruptor_dir, "decoder.jit"), + ) + load_network_model(diffusion_decoder_model, os.path.join(diffusion_decoder_ckpt_path, "model.pt")) + load_tokenizer_model(diffusion_decoder_model, dd_tokenizer_dir) + + generic_prompt = dict() + aux_vars = torch.load(os.path.join(diffusion_decoder_ckpt_path, "aux_vars.pt"), weights_only=True) + generic_prompt["context"] = aux_vars["context"].cuda() + generic_prompt["context_mask"] = aux_vars["context_mask"].cuda() + + output_video = diffusion_decoder_process_tokens( + model=diffusion_decoder_model, + indices_tensor=indices_tensor_cur_batch, + dd_sampling_config=DiffusionDecoderSamplingConfig(), + original_video_example=out_videos_cur_batch[0], + t5_emb_batch=[generic_prompt["context"]], + ) + + del diffusion_decoder_model + diffusion_decoder_model = None + gc.collect() + torch.cuda.empty_cache() + + return output_video + + +def load_model_by_config( + config_job_name, + 
config_file="projects/cosmos_video/config/config.py", + model_class=LatentDiffusionDecoderModel, + encoder_path=None, + decoder_path=None, +): + config_module = get_config_module(config_file) + config = importlib.import_module(config_module).make_config() + + config = override(config, ["--", f"experiment={config_job_name}"]) + + # Check that the config is valid + config.validate() + # Freeze the config so developers don't change it during training. + config.freeze() # type: ignore + if encoder_path: + config.model.tokenizer_corruptor["enc_fp"] = encoder_path + if decoder_path: + config.model.tokenizer_corruptor["dec_fp"] = decoder_path + # Initialize model + with skip_init_linear(): + model = model_class(config.model) + return model diff --git a/cosmos1/models/autoregressive/tokenizer/__init__.py b/cosmos1/models/autoregressive/tokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos1/models/autoregressive/utils/__init__.py b/cosmos1/models/autoregressive/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos1/models/autoregressive/utils/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos1/models/autoregressive/utils/inference.py b/cosmos1/models/autoregressive/utils/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..d75386b2b2d5ded1ec3991e205dd2d5d998c7acd --- /dev/null +++ b/cosmos1/models/autoregressive/utils/inference.py @@ -0,0 +1,360 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import math +import os +from pathlib import Path +from typing import List + +import numpy as np +import torch +import torchvision +from PIL import Image + +from inference_config import SamplingConfig +from .log import log + +_IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", "webp"] +_VIDEO_EXTENSIONS = [".mp4"] +_SUPPORTED_CONTEXT_LEN = [1, 9] # Input frames +NUM_TOTAL_FRAMES = 33 + + +def add_common_arguments(parser): + """Add common command line arguments. 
+ + Args: + parser (ArgumentParser): Argument parser to add arguments to + """ + parser.add_argument( + "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints" + ) + parser.add_argument( + "--video_save_name", + type=str, + default="output", + help="Output filename for generating a single video", + ) + parser.add_argument("--video_save_folder", type=str, default="outputs/", help="Output folder for saving videos") + parser.add_argument( + "--input_image_or_video_path", + type=str, + help="Input path for input image or video", + ) + parser.add_argument( + "--batch_input_path", + type=str, + help="Input folder containing all input images or videos", + ) + parser.add_argument( + "--num_input_frames", + type=int, + default=9, + help="Number of input frames for world generation", + choices=_SUPPORTED_CONTEXT_LEN, + ) + parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling") + parser.add_argument("--top_p", type=float, default=0.8, help="Top-p value for sampling") + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument("--disable_diffusion_decoder", action="store_true", help="Disable diffusion decoder") + parser.add_argument( + "--offload_guardrail_models", + action="store_true", + help="Offload guardrail models after inference", + ) + parser.add_argument( + "--offload_diffusion_decoder", + action="store_true", + help="Offload diffusion decoder after inference", + ) + parser.add_argument( + "--offload_ar_model", + action="store_true", + help="Offload AR model after inference", + ) + parser.add_argument( + "--offload_tokenizer", + action="store_true", + help="Offload discrete tokenizer model after inference", + ) + + +def validate_args(args: argparse.Namespace, inference_type: str): + """Validate command line arguments for base and video2world generation.""" + assert inference_type in [ + "base", + "video2world", + ], "Invalid inference_type, must be 'base' or 'video2world'" + if args.input_type in ["image", "text_and_image"] and args.num_input_frames != 1: + args.num_input_frames = 1 + log.info(f"Set num_input_frames to 1 for {args.input_type} input") + + if args.num_input_frames == 1: + if "4B" in args.ar_model_dir: + log.warning( + "The failure rate for 4B model with image input is ~15%. 12B / 13B model have a smaller failure rate. Please be cautious and refer to README.md for more details." + ) + elif "5B" in args.ar_model_dir: + log.warning( + "The failure rate for 5B model with image input is ~7%. 12B / 13B model have a smaller failure rate. Please be cautious and refer to README.md for more details." + ) + + # Validate prompt/image/video args for single or batch generation + assert ( + args.input_image_or_video_path or args.batch_input_path + ), "--input_image_or_video_path or --batch_input_path must be provided." + if inference_type == "video2world" and (not args.batch_input_path): + assert args.prompt, "--prompt is required for single video generation." 
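+    # All inputs are resized to the fixed 640x1024 resolution used throughout this pipeline (see resize_input / load_vision_input below).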
+ args.data_resolution = [640, 1024] + + # Validate number of GPUs + num_gpus = int(os.getenv("WORLD_SIZE", 1)) + assert num_gpus <= 1, "We support only single GPU inference for now" + + # Create output folder + Path(args.video_save_folder).mkdir(parents=True, exist_ok=True) + + sampling_config = SamplingConfig( + echo=True, + temperature=args.temperature, + top_p=args.top_p, + compile_sampling=True, + ) + return sampling_config + + +def resize_input(video: torch.Tensor, resolution: list[int]): + r""" + Function to perform aspect ratio preserving resizing and center cropping. + This is needed to make the video into target resolution. + Args: + video (torch.Tensor): Input video tensor + resolution (list[int]): Data resolution + Returns: + Cropped video + """ + + orig_h, orig_w = video.shape[2], video.shape[3] + target_h, target_w = resolution + + scaling_ratio = max((target_w / orig_w), (target_h / orig_h)) + resizing_shape = (int(math.ceil(scaling_ratio * orig_h)), int(math.ceil(scaling_ratio * orig_w))) + video_resized = torchvision.transforms.functional.resize(video, resizing_shape) + video_cropped = torchvision.transforms.functional.center_crop(video_resized, resolution) + return video_cropped + + +def load_image_from_list(flist, data_resolution: List[int]) -> dict: + """ + Function to load images from a list of image paths. + Args: + flist (List[str]): List of image paths + data_resolution (List[int]): Data resolution + Returns: + Dict containing input images + """ + all_videos = dict() + for img_path in flist: + ext = os.path.splitext(img_path)[1] + if ext in _IMAGE_EXTENSIONS: + # Read the image + img = Image.open(img_path) + + # Convert to tensor + img = torchvision.transforms.functional.to_tensor(img) + static_vid = img.unsqueeze(0).repeat(NUM_TOTAL_FRAMES, 1, 1, 1) + static_vid = static_vid * 2 - 1 + + log.debug( + f"Resizing input image of shape ({static_vid.shape[2]}, {static_vid.shape[3]}) -> ({data_resolution[0]}, {data_resolution[1]})" + ) + static_vid = resize_input(static_vid, data_resolution) + fname = os.path.basename(img_path) + all_videos[fname] = static_vid.transpose(0, 1).unsqueeze(0) + + return all_videos + + +def read_input_images(batch_input_path: str, data_resolution: List[int]) -> dict: + """ + Function to read input images from a JSONL file. + + Args: + batch_input_path (str): Path to JSONL file containing visual input paths + data_resolution (list[int]): Data resolution + + Returns: + Dict containing input images + """ + # Read visual inputs from JSONL + flist = [] + with open(batch_input_path, "r") as f: + for line in f: + data = json.loads(line.strip()) + flist.append(data["visual_input"]) + + return load_image_from_list(flist, data_resolution=data_resolution) + + +def read_input_image(input_path: str, data_resolution: List[int]) -> dict: + """ + Function to read input image. + Args: + input_path (str): Path to input image + data_resolution (List[int]): Data resolution + Returns: + Dict containing input image + """ + flist = [input_path] + return load_image_from_list(flist, data_resolution=data_resolution) + + +def read_input_videos(batch_input_path: str, data_resolution: List[int], num_input_frames: int) -> dict: + r""" + Function to read input videos. 
+ Args: + batch_input_path (str): Path to JSONL file containing visual input paths + data_resolution (list[int]): Data resolution + Returns: + Dict containing input videos + """ + # Read visual inputs from JSONL + flist = [] + with open(batch_input_path, "r") as f: + for line in f: + data = json.loads(line.strip()) + flist.append(data["visual_input"]) + return load_videos_from_list(flist, data_resolution=data_resolution, num_input_frames=num_input_frames) + + +def read_input_video(input_path: str, data_resolution: List[int], num_input_frames: int) -> dict: + """ + Function to read input video. + Args: + input_path (str): Path to input video + data_resolution (List[int]): Data resolution + num_input_frames (int): Number of frames in context + Returns: + Dict containing input video + """ + flist = [input_path] + return load_videos_from_list(flist, data_resolution=data_resolution, num_input_frames=num_input_frames) + + +def load_videos_from_list(flist: List[str], data_resolution: List[int], num_input_frames: int) -> dict: + """ + Function to load videos from a list of video paths. + Args: + flist (List[str]): List of video paths + data_resolution (List[int]): Data resolution + num_input_frames (int): Number of frames in context + Returns: + Dict containing input videos + """ + all_videos = dict() + + for video_path in flist: + ext = os.path.splitext(video_path)[-1] + if ext in _VIDEO_EXTENSIONS: + video, _, _ = torchvision.io.read_video(video_path, pts_unit="sec") + video = video.float() / 255.0 + video = video * 2 - 1 + + # Resize the videos to the required dimension + nframes_in_video = video.shape[0] + if nframes_in_video < num_input_frames: + fname = os.path.basename(video_path) + log.warning( + f"Video {fname} has {nframes_in_video} frames, less than the requried {num_input_frames} frames. Skipping." + ) + continue + + video = video[-num_input_frames:, :, :, :] + + # Pad the video to NUM_TOTAL_FRAMES (because the tokenizer expects inputs of NUM_TOTAL_FRAMES) + video = torch.cat( + (video, video[-1, :, :, :].unsqueeze(0).repeat(NUM_TOTAL_FRAMES - num_input_frames, 1, 1, 1)), + dim=0, + ) + + video = video.permute(0, 3, 1, 2) + + log.debug( + f"Resizing input video of shape ({video.shape[2]}, {video.shape[3]}) -> ({data_resolution[0]}, {data_resolution[1]})" + ) + video = resize_input(video, data_resolution) + + fname = os.path.basename(video_path) + all_videos[fname] = video.transpose(0, 1).unsqueeze(0) + + return all_videos + + +def load_vision_input( + input_type: str, + batch_input_path: str, + input_image_or_video_path: str, + data_resolution: List[int], + num_input_frames: int, +): + """ + Function to load vision input. + Note: We pad the frames of the input image/video to NUM_TOTAL_FRAMES here, and feed the padded video tensors to the video tokenizer to obtain tokens. The tokens will be truncated based on num_input_frames when feeding to the autoregressive model. 
+ Args: + input_type (str): Type of input + batch_input_path (str): Folder containing input images or videos + input_image_or_video_path (str): Path to input image or video + data_resolution (List[int]): Data resolution + num_input_frames (int): Number of frames in context + Returns: + Dict containing input videos + """ + if batch_input_path: + log.info(f"Reading batch inputs from path: {batch_input_path}") + if input_type == "image" or input_type == "text_and_image": + input_videos = read_input_images(batch_input_path, data_resolution=data_resolution) + elif input_type == "video" or input_type == "text_and_video": + input_videos = read_input_videos( + batch_input_path, + data_resolution=data_resolution, + num_input_frames=num_input_frames, + ) + else: + raise ValueError(f"Invalid input type {input_type}") + else: + if input_type == "image" or input_type == "text_and_image": + input_videos = read_input_image(input_image_or_video_path, data_resolution=data_resolution) + elif input_type == "video" or input_type == "text_and_video": + input_videos = read_input_video( + input_image_or_video_path, + data_resolution=data_resolution, + num_input_frames=num_input_frames, + ) + else: + raise ValueError(f"Invalid input type {input_type}") + return input_videos + + +def prepare_video_batch_for_saving(video_batch: List[torch.Tensor]) -> List[np.ndarray]: + """ + Function to convert output tensors to numpy format for saving. + Args: + video_batch (List[torch.Tensor]): List of output tensors + Returns: + List of numpy arrays + """ + return [(video * 255).to(torch.uint8).permute(1, 2, 3, 0).cpu().numpy() for video in video_batch] diff --git a/cosmos1/models/diffusion/README.md b/cosmos1/models/diffusion/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d088d8e4fc31cb4f5edd316e793f6372056db42c --- /dev/null +++ b/cosmos1/models/diffusion/README.md @@ -0,0 +1,408 @@ +# Cosmos Diffusion-based World Foundation Models + +## Table of Contents +- [Getting Started](#getting-started) + - [Set Up Docker Environment](#set-up-docker-environment) + - [Download Checkpoints](#download-checkpoints) +- [Usage](#usage) + - [Model Types](#model-types) + - [Single and Batch Generation](#single-and-batch-generation) + - [Sample Commands](#sample-commands) + - [Text2World](#text2world-text2worldpy-7b-and-14b) + - [Video2World](#video2world-video2worldpy-7b-and-14b) + - [Arguments](#arguments) + - [Common Parameters](#common-parameters) + - [Text2World Specific Parameters](#text2world-specific-parameters) + - [Video2World Specific Parameters](#video2world-specific-parameters) + - [Safety Features](#safety-features) + - [Prompting Instructions](#prompting-instructions) + +This page details the steps for using the Cosmos diffusion-based world foundation models. + +## Getting Started + +### Set Up Docker Environment + +Follow our [Installation Guide](../../../INSTALL.md) to set up the Docker environment. All commands on this page should be run inside Docker. + +### Download Checkpoints + +1. Generate a [Hugging Face](https://huggingface.co/settings/tokens) access token. Set the access token to 'Read' permission (default is 'Fine-grained'). + +2. Log in to Hugging Face with the access token: + +```bash +huggingface-cli login +``` + +3. Request access to Mistral AI's Pixtral-12B model by clicking on `Agree and access repository` on [Pixtral's Hugging Face model page](https://huggingface.co/mistralai/Pixtral-12B-2409). This step is required to use Pixtral 12B for the Video2World prompt upsampling task. 
+ +4. Download the Cosmos model weights from [Hugging Face](https://huggingface.co/collections/nvidia/cosmos-6751e884dc10e013a0a0d8e6): + +```bash +PYTHONPATH=$(pwd) python cosmos1/scripts/download_diffusion.py --model_sizes 7B 14B --model_types Text2World Video2World +``` + +5. The downloaded files should be in the following structure: + +``` +checkpoints/ +├── Cosmos-1.0-Diffusion-7B-Text2World +│ ├── model.pt +│ └── config.json +├── Cosmos-1.0-Diffusion-14B-Text2World +│ ├── model.pt +│ └── config.json +├── Cosmos-1.0-Diffusion-7B-Video2World +│ ├── model.pt +│ └── config.json +├── Cosmos-1.0-Diffusion-14B-Video2World +│ ├── model.pt +│ └── config.json +├── Cosmos-1.0-Tokenizer-CV8x8x8 +│ ├── decoder.jit +│ ├── encoder.jit +│ └── mean_std.pt +├── Cosmos-1.0-Prompt-Upsampler-12B-Text2World +│ ├── model.pt +│ └── config.json +├── Pixtral-12B +│ ├── model.pt +│ ├── config.json +└── Cosmos-1.0-Guardrail + ├── aegis/ + ├── blocklist/ + ├── face_blur_filter/ + └── video_content_safety_filter/ +``` + +## Usage + +### Model Types + +There are two model types available for diffusion world generation: + +1. **Text2World**: Supports world generation from text input + +* Models: `Cosmos-1.0-Diffusion-7B-Text2World` and `Cosmos-1.0-Diffusion-14B-Text2World` +* Inference script: [text2world.py](/cosmos1/models/diffusion/inference/text2world.py) + +2. **Video2World**: Supports world generation from text and image/video input + +* Models: `Cosmos-1.0-Diffusion-7B-Video2World` and `Cosmos-1.0-Diffusion-14B-Video2World` +* Inference script: [video2world.py](/cosmos1/models/diffusion/inference/video2world.py) + +### Single and Batch Generation + +We support both single and batch video generation. + +For generating a single video, `Text2World` mode requires the input argument `--prompt` (text input). `Video2World` mode requires `--input_image_or_video_path` (image/video input). Additionally for Video2World, if the prompt upsampler is disabled, a text prompt must also be provided using the `--prompt` argument. + +For generating a batch of videos, both `Text2World` and `Video2World` require `--batch_input_path` (path to a JSONL file). For `Text2World`, the JSONL file should contain one prompt per line in the following format, where each line must contain a "prompt" field: + +```json +{"prompt": "prompt1"} +{"prompt": "prompt2"} +``` + +For `Video2World`, each line in the JSONL file must contain a "visual_input" field: + +```json +{"visual_input": "path/to/video1.mp4"} +{"visual_input": "path/to/video2.mp4"} +``` + +If you disable the prompt upsampler by setting the `--disable_prompt_upsampler` flag, each line in the JSONL file will need to include both "prompt" and "visual_input" fields. + +```json +{"prompt": "prompt1", "visual_input": "path/to/video1.mp4"} +{"prompt": "prompt2", "visual_input": "path/to/video2.mp4"} +``` + +### Sample Commands + +There are two main demo scripts for diffusion world generation: `text2world.py` and `video2world.py`. Below you will find sample commands for single and batch generation, as well as commands for running with low-memory GPUs using model offloading. We also provide a memory usage table comparing different offloading strategies to help with configuration. + +#### Text2World (text2world.py): 7B and 14B + +Generates world from text input. + +##### Single Generation + +```bash +PROMPT="A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. 
\ +The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. \ +A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, \ +suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. \ +The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of \ +field that keeps the focus on the robot while subtly blurring the background for a cinematic effect." + +# Example using 7B model +PYTHONPATH=$(pwd) python cosmos1/models/diffusion/inference/text2world.py \ + --checkpoint_dir checkpoints \ + --diffusion_transformer_dir Cosmos-1.0-Diffusion-7B-Text2World \ + --prompt "$PROMPT" \ + --offload_prompt_upsampler \ + --video_save_name Cosmos-1.0-Diffusion-7B-Text2World + +# Example using the 7B model on low-memory GPUs with model offloading. The speed is slower if using batch generation. +PYTHONPATH=$(pwd) python cosmos1/models/diffusion/inference/text2world.py \ + --checkpoint_dir checkpoints \ + --diffusion_transformer_dir Cosmos-1.0-Diffusion-7B-Text2World \ + --prompt "$PROMPT" \ + --video_save_name Cosmos-1.0-Diffusion-7B-Text2World_memory_efficient \ + --offload_tokenizer \ + --offload_diffusion_transformer \ + --offload_text_encoder_model \ + --offload_prompt_upsampler \ + --offload_guardrail_models + +# Example using 14B model with prompt upsampler offloading (required on H100) +PYTHONPATH=$(pwd) python cosmos1/models/diffusion/inference/text2world.py \ + --checkpoint_dir checkpoints \ + --diffusion_transformer_dir Cosmos-1.0-Diffusion-14B-Text2World \ + --prompt "$PROMPT" \ + --video_save_name Cosmos-1.0-Diffusion-14B-Text2World \ + --offload_prompt_upsampler \ + --offload_guardrail_models +``` + +##### Batch Generation + +```bash +# Example using 7B model +PYTHONPATH=$(pwd) python cosmos1/models/diffusion/inference/text2world.py \ + --checkpoint_dir checkpoints \ + --diffusion_transformer_dir Cosmos-1.0-Diffusion-7B-Text2World \ + --batch_input_path cosmos1/models/diffusion/assets/v1p0/batch_inputs/text2world.jsonl \ + --video_save_folder outputs/Cosmos-1.0-Diffusion-7B-Text2World \ + --offload_prompt_upsampler +``` + +##### Example Output + +Here is an example output video generated using text2world.py: + + + +The upsampled prompt used to generate the video is: + +``` +In a sprawling, meticulously organized warehouse, a sleek humanoid robot stands sentinel amidst towering shelves brimming with neatly stacked cardboard boxes. The robot's metallic body, adorned with intricate joints and a glowing blue chest light, radiates an aura of advanced technology, its design a harmonious blend of functionality and futuristic elegance. The camera captures this striking figure in a static, wide shot, emphasizing its poised stance against the backdrop of industrial wooden pallets. The lighting is bright and even, casting a warm glow that accentuates the robot's form, while the shallow depth of field subtly blurs the rows of boxes, creating a cinematic depth that draws the viewer into this high-tech realm. The absence of human presence amplifies the robot's solitary vigil, inviting contemplation of its purpose within this vast, organized expanse. 
+``` + +If you disable the prompt upsampler by using the `--disable_prompt_upsampler` flag, the output video will be generated using the original prompt: + + + +The original prompt is: +``` +A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of field that keeps the focus on the robot while subtly blurring the background for a cinematic effect. +``` + +Note that the robot face could be blurred sometimes by the guardrail in this example. + +##### Inference Time and GPU Memory Usage + +The numbers provided below may vary depending on system specs and are for reference only. + +We report the maximum observed GPU memory usage during end-to-end inference. Additionally, we offer a series of model offloading strategies to help users manage GPU memory usage effectively. + +For GPUs with limited memory (e.g., RTX 3090/4090 with 24 GB memory), we recommend fully offloading all models. For higher-end GPUs, users can select the most suitable offloading strategy considering the numbers provided below. + +| Offloading Strategy | 7B Text2World | 14B Text2World | +|-------------|---------|---------| +| Offload prompt upsampler | 74.0 GB | > 80.0 GB | +| Offload prompt upsampler & guardrails | 57.1 GB | 70.5 GB | +| Offload prompt upsampler & guardrails & T5 encoder | 38.5 GB | 51.9 GB | +| Offload prompt upsampler & guardrails & T5 encoder & tokenizer | 38.3 GB | 51.7 GB | +| Offload prompt upsampler & guardrails & T5 encoder & tokenizer & diffusion model | 24.4 GB | 39.0 GB | + +The table below presents the end-to-end inference runtime on a single H100 GPU, excluding model initialization time. + +| 7B Text2World (offload prompt upsampler) | 14B Text2World (offload prompt upsampler, guardrails) | +|---------|---------| +| ~380 seconds | ~590 seconds | + +#### Video2World (video2world.py): 7B and 14B + +Generates world from text and image/video input. + +##### Single Generation + +Note that our prompt upsampler is enabled by default for Video2World, and it will generate the prompt from the input image/video. If the prompt upsampler is disabled, you can provide a prompt manually using the `--prompt` flag. + +```bash +# Example using the 7B model +PYTHONPATH=$(pwd) python cosmos1/models/diffusion/inference/video2world.py \ + --checkpoint_dir checkpoints \ + --diffusion_transformer_dir Cosmos-1.0-Diffusion-7B-Video2World \ + --input_image_or_video_path cosmos1/models/diffusion/assets/v1p0/video2world_input0.jpg \ + --num_input_frames 1 \ + --video_save_name Cosmos-1.0-Diffusion-7B-Video2World \ + --offload_prompt_upsampler + +# Example using the 7B model on low-memory GPUs with model offloading. The speed is slower if using batch generation. 
+PYTHONPATH=$(pwd) python cosmos1/models/diffusion/inference/video2world.py \ + --checkpoint_dir checkpoints \ + --diffusion_transformer_dir Cosmos-1.0-Diffusion-7B-Video2World \ + --input_image_or_video_path cosmos1/models/diffusion/assets/v1p0/video2world_input0.jpg \ + --num_input_frames 1 \ + --video_save_name Cosmos-1.0-Diffusion-7B-Video2World_memory_efficient \ + --offload_tokenizer \ + --offload_diffusion_transformer \ + --offload_text_encoder_model \ + --offload_prompt_upsampler \ + --offload_guardrail_models + +# Example using 14B model with prompt upsampler offloading (required on H100) +PYTHONPATH=$(pwd) python cosmos1/models/diffusion/inference/video2world.py \ + --checkpoint_dir checkpoints \ + --diffusion_transformer_dir Cosmos-1.0-Diffusion-14B-Video2World \ + --input_image_or_video_path cosmos1/models/diffusion/assets/v1p0/video2world_input0.jpg \ + --num_input_frames 1 \ + --video_save_name Cosmos-1.0-Diffusion-14B-Video2World \ + --offload_prompt_upsampler \ + --offload_guardrail_models +``` + +##### Batch Generation + +```bash +# Example using 7B model with 9 input frames +PYTHONPATH=$(pwd) python cosmos1/models/diffusion/inference/video2world.py \ + --checkpoint_dir checkpoints \ + --diffusion_transformer_dir Cosmos-1.0-Diffusion-7B-Video2World \ + --batch_input_path cosmos1/models/diffusion/assets/v1p0/batch_inputs/video2world_ps.jsonl \ + --video_save_folder outputs/Cosmos-1.0-Diffusion-7B-Video2World \ + --offload_prompt_upsampler \ + --num_input_frames 9 + +# Example using 7B model with 9 input frames without prompt upsampler, using 'prompt' field in the JSONL file +PYTHONPATH=$(pwd) python cosmos1/models/diffusion/inference/video2world.py \ + --checkpoint_dir checkpoints \ + --diffusion_transformer_dir Cosmos-1.0-Diffusion-7B-Video2World \ + --batch_input_path cosmos1/models/diffusion/assets/v1p0/batch_inputs/video2world_wo_ps.jsonl \ + --video_save_folder outputs/Cosmos-1.0-Diffusion-7B-Video2World_wo_ps \ + --disable_prompt_upsampler \ + --num_input_frames 9 +``` + +##### Example Output + +Here is an example output video generated using video2world.py, using `Cosmos-1.0-Diffusion-14B-Video2World`: + + + +The upsampled prompt (generated by the prompt upsampler) used to generate the video is: + +``` +The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day. +``` + +##### Inference Time and GPU Memory Usage + +The numbers provided below may vary depending on system specs and are for reference only. 
+ +| Offloading Strategy | 7B Video2World | 14B Video2World | +|----------------------------------------------------------------------------------|---------|---------| +| Offload prompt upsampler | 76.5 GB | > 80.0 GB | +| Offload prompt upsampler & guardrails | 59.9 GB | 73.3 GB | +| Offload prompt upsampler & guardrails & T5 encoder | 41.3 GB | 54.8 GB | +| Offload prompt upsampler & guardrails & T5 encoder & tokenizer | 41.1 GB | 54.5 GB | +| Offload prompt upsampler & guardrails & T5 encoder & tokenizer & diffusion model | 27.3 GB | 39.0 GB | + +The following table shows the end-to-end inference runtime on a single H100 GPU, excluding model initialization time: + +| 7B Video2World (offload prompt upsampler) | 14B Video2World (offload prompt upsampler, guardrails) | +|---------|---------| +| ~383 seconds | ~593 seconds | + +### Arguments + +#### Common Parameters + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `--checkpoint_dir` | Directory containing model weights | "checkpoints" | +| `--tokenizer_dir` | Directory containing tokenizer weights | "Cosmos-1.0-Tokenizer-CV8x8x8" | +| `--video_save_name` | Output video filename for single video generation | "output" | +| `--video_save_folder` | Output directory for batch video generation | "outputs/" | +| `--prompt` | Text prompt for single video generation. Required for single video generation. | None | +| `--batch_input_path` | Path to JSONL file for batch video generation. Required for batch video generation. | None | +| `--negative_prompt` | Negative prompt for improved quality | "The video captures a series of frames showing ugly scenes..." | +| `--num_steps` | Number of diffusion sampling steps | 35 | +| `--guidance` | CFG guidance scale | 7.0 | +| `--num_video_frames` | Number of frames to generate | 121 | +| `--height` | Output video height | 704 | +| `--width` | Output video width | 1280 | +| `--fps` | Frames per second | 24 | +| `--seed` | Random seed | 1 | +| `--disable_prompt_upsampler` | Disable automatic prompt enhancement | False | +| `--offload_diffusion_transformer` | Offload DiT model after inference, used for low-memory GPUs | False | +| `--offload_tokenizer` | Offload VAE model after inference, used for low-memory GPUs | False | +| `--offload_text_encoder_model` | Offload text encoder after inference, used for low-memory GPUs | False | +| `--offload_prompt_upsampler` | Offload prompt upsampler after inference, used for low-memory GPUs | False | +| `--offload_guardrail_models` | Offload guardrail models after inference, used for low-memory GPUs | False | + +Note: we support various aspect ratios, including 1:1 (960x960 for height and width), 4:3 (960x704), 3:4 (704x960), 16:9 (1280x704), and 9:16 (704x1280). The frame rate is also adjustable within a range of 12 to 40 fps. 
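+
+For example, to target one of the other supported aspect ratios, pass the resolution and frame-rate flags from the table above explicitly. The following is a sketch based on the sample 7B Text2World command shown earlier; the 1:1 resolution, frame rate, and output name are illustrative values, and `$PROMPT` is assumed to be set as in that example:
+
+```bash
+# 1:1 (960x960) generation at 30 fps with the 7B Text2World model
+PYTHONPATH=$(pwd) python cosmos1/models/diffusion/inference/text2world.py \
+    --checkpoint_dir checkpoints \
+    --diffusion_transformer_dir Cosmos-1.0-Diffusion-7B-Text2World \
+    --prompt "$PROMPT" \
+    --height 960 --width 960 --fps 30 \
+    --offload_prompt_upsampler \
+    --video_save_name Cosmos-1.0-Diffusion-7B-Text2World_square
+```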
+ +#### Text2World Specific Parameters + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `--diffusion_transformer_dir` | Directory containing DiT weights | "Cosmos-1.0-Diffusion-7B-Text2World" | +| `--prompt_upsampler_dir` | Directory containing prompt upsampler weights | "Cosmos-1.0-Prompt-Upsampler-12B-Text2World" | +| `--word_limit_to_skip_upsampler` | Skip prompt upsampler for better robustness if the number of words in the prompt is greater than this value | 250 | +#### Video2World Specific Parameters + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `--diffusion_transformer_dir` | Directory containing DiT weights | "Cosmos-1.0-Diffusion-7B-Video2World" | +| `--prompt_upsampler_dir` | Directory containing prompt upsampler weights | "Pixtral-12B" | +| `--input_image_or_video_path` | Input video/image path for single video generation. Required for single video generation. | None | +| `--num_input_frames` | Number of video frames (1 or 9) | 1 | + +### Safety Features + +The model uses a built-in safety guardrail system that cannot be disabled. Generating human faces is not allowed and will be blurred by the guardrail. + +For more information, check out the [Cosmos Guardrail Documentation](../guardrail/README.md). + +### Prompting Instructions + +The input prompt is the most important parameter under the user's control when interacting with the model. Providing rich and descriptive prompts can positively impact the output quality of the model, whereas short and poorly detailed prompts can lead to subpar video generation. Here are some recommendations to keep in mind when crafting text prompts for the model: + +1. **Describe a single, captivating scene**: Focus on a single scene to prevent the model from generating videos with unnecessary shot changes. +2. **Limit camera control instructions**: The model doesn't handle prompts involving camera control well, as this feature is still under development. +3. **Prompt upsampler limitations**: The current version of the prompt upsampler may sometimes deviate from the original intent of your prompt, adding unwanted details. If this happens, you can disable the upsampler with the --disable_prompt_upsampler flag and edit your prompt manually. We recommend using prompts of around 120 words for optimal quality. + +#### Cosmos-1.0-Prompt-Upsampler + +The prompt upsampler automatically expands brief prompts into more detailed descriptions (Text2World) or generates detailed prompts based on input images (Video2World). + +##### Text2World + +When enabled (default), the upsampler will: + +1. Take your input prompt +2. Process it through a finetuned Mistral model to generate a more detailed description +3. Use the expanded description for video generation + +This can help generate better quality videos by providing more detailed context to the video generation model. To disable this feature, use the `--disable_prompt_upsampler` flag. + +##### Video2World + +When enabled (default), the upsampler will: + +1. Take your input image or video +2. Process it through a Pixtral model to generate a detailed description +3. Use the generated description for video generation + +Please note that the Video2World prompt upsampler does not consider any user-provided text prompt. To disable this feature, use the `--disable_prompt_upsampler` flag. 
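+
+As a concrete example of the workflow described above, the following sketch reuses the sample 7B Video2World command from the [Sample Commands](#sample-commands) section with the upsampler disabled and a manually written prompt; the prompt text and output name are placeholders to replace with your own:
+
+```bash
+PYTHONPATH=$(pwd) python cosmos1/models/diffusion/inference/video2world.py \
+    --checkpoint_dir checkpoints \
+    --diffusion_transformer_dir Cosmos-1.0-Diffusion-7B-Video2World \
+    --input_image_or_video_path cosmos1/models/diffusion/assets/v1p0/video2world_input0.jpg \
+    --num_input_frames 1 \
+    --disable_prompt_upsampler \
+    --prompt "A robot arm picks up a box from a conveyor belt and places it on a pallet." \
+    --video_save_name Cosmos-1.0-Diffusion-7B-Video2World_manual_prompt
+```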
diff --git a/cosmos1/models/diffusion/assets/nemo/text2world_example_after_finetune.mp4 b/cosmos1/models/diffusion/assets/nemo/text2world_example_after_finetune.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..58926f58c86ed3ba27396e52ec32cbe9404e1f12 Binary files /dev/null and b/cosmos1/models/diffusion/assets/nemo/text2world_example_after_finetune.mp4 differ diff --git a/cosmos1/models/diffusion/assets/v1p0/batch_inputs/text2world.jsonl b/cosmos1/models/diffusion/assets/v1p0/batch_inputs/text2world.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..38fc3aec514c0e62fd6122a1f773d7af6e8f84e1 --- /dev/null +++ b/cosmos1/models/diffusion/assets/v1p0/batch_inputs/text2world.jsonl @@ -0,0 +1,2 @@ +{"prompt": "A tugger train glides smoothly through the factory floor. Its trailers are filled with a mix of neatly arranged crates and humanoid robots, each securely positioned and ready for delivery. The robots, sleek and functional in design, sit upright alongside the crates, emphasizing their integration into the production process. The train moves along a marked path, highlighting the organized and efficient flow of materials and automation in the bustling environment. The crates appear carefully packed, and the robots\u2019 purposeful placement suggests their importance in streamlining operations."} +{"prompt": "A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of field that keeps the focus on the robot while subtly blurring the background for a cinematic effect."} diff --git a/cosmos1/models/diffusion/assets/v1p0/batch_inputs/video2world_ps.jsonl b/cosmos1/models/diffusion/assets/v1p0/batch_inputs/video2world_ps.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..b378bd806d0aec4050e1b25b0d6c045e8cdc2c0b --- /dev/null +++ b/cosmos1/models/diffusion/assets/v1p0/batch_inputs/video2world_ps.jsonl @@ -0,0 +1,2 @@ +{"visual_input": "cosmos1/models/diffusion/assets/v1p0/video2world_input1.mp4"} +{"visual_input": "cosmos1/models/diffusion/assets/v1p0/video2world_input2.mp4"} diff --git a/cosmos1/models/diffusion/assets/v1p0/batch_inputs/video2world_wo_ps.jsonl b/cosmos1/models/diffusion/assets/v1p0/batch_inputs/video2world_wo_ps.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..80adecf46e754b39063f8cb6f97819aa3111b254 --- /dev/null +++ b/cosmos1/models/diffusion/assets/v1p0/batch_inputs/video2world_wo_ps.jsonl @@ -0,0 +1,2 @@ +{"prompt": "The video depicts a vehicle driving along a sandy beach, leaving a trail of sand in its wake. The car is moving parallel to the shoreline, with the ocean waves gently lapping at the sand. The scene captures the essence of a serene coastal drive, emphasizing the interaction between the vehicle and the natural environment.", "visual_input": "cosmos1/models/diffusion/assets/v1p0/video2world_input1.mp4"} +{"prompt": "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. 
The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region.", "visual_input": "cosmos1/models/diffusion/assets/v1p0/video2world_input2.mp4"} diff --git a/cosmos1/models/diffusion/assets/v1p0/text2world_example.mp4 b/cosmos1/models/diffusion/assets/v1p0/text2world_example.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..bb0b15ce33aed08ec43326fe4d777a3f714496d7 Binary files /dev/null and b/cosmos1/models/diffusion/assets/v1p0/text2world_example.mp4 differ diff --git a/cosmos1/models/diffusion/assets/v1p0/text2world_example2.mp4 b/cosmos1/models/diffusion/assets/v1p0/text2world_example2.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..c593fb768cbce22a5c0b6d4b31c639c841178089 Binary files /dev/null and b/cosmos1/models/diffusion/assets/v1p0/text2world_example2.mp4 differ diff --git a/cosmos1/models/diffusion/assets/v1p0/video2world_example.mp4 b/cosmos1/models/diffusion/assets/v1p0/video2world_example.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3f3d51a78c84e5dac64924910f9a7445c6f3b362 Binary files /dev/null and b/cosmos1/models/diffusion/assets/v1p0/video2world_example.mp4 differ diff --git a/cosmos1/models/diffusion/assets/v1p0/video2world_input0.jpg b/cosmos1/models/diffusion/assets/v1p0/video2world_input0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..108b718d810b1fdff547206e1ab240c40bb23025 Binary files /dev/null and b/cosmos1/models/diffusion/assets/v1p0/video2world_input0.jpg differ diff --git a/cosmos1/models/diffusion/assets/v1p0/video2world_input1.mp4 b/cosmos1/models/diffusion/assets/v1p0/video2world_input1.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..54e752e51fd3fff6c76dc164268c482920485e9d Binary files /dev/null and b/cosmos1/models/diffusion/assets/v1p0/video2world_input1.mp4 differ diff --git a/cosmos1/models/diffusion/assets/v1p0/video2world_input2.mp4 b/cosmos1/models/diffusion/assets/v1p0/video2world_input2.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..adaffe21e9c7d23937c12bd6cdf63705370a3289 Binary files /dev/null and b/cosmos1/models/diffusion/assets/v1p0/video2world_input2.mp4 differ diff --git a/cosmos1/models/diffusion/config/__init__.py b/cosmos1/models/diffusion/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos1/models/diffusion/config/base/__init__.py b/cosmos1/models/diffusion/config/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos1/models/diffusion/config/inference/__init__.py b/cosmos1/models/diffusion/config/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos1/models/diffusion/module/__init__.py b/cosmos1/models/diffusion/module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos1/models/diffusion/nemo/__init__.py b/cosmos1/models/diffusion/nemo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b11bcc2d6929f824fade4d22634d2c9a8c7811a7 --- /dev/null +++ b/cosmos1/models/diffusion/nemo/__init__.py @@ -0,0 +1,14 @@ +# 
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos1/models/diffusion/nemo/download_diffusion_nemo.py b/cosmos1/models/diffusion/nemo/download_diffusion_nemo.py new file mode 100644 index 0000000000000000000000000000000000000000..61d014c9add031f93870d4ee9d81768afd45674f --- /dev/null +++ b/cosmos1/models/diffusion/nemo/download_diffusion_nemo.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from huggingface_hub import snapshot_download + + +def download_diffusion_nemo(): + """ + Downloads all Cosmos Diffusion NeMo assets to HF_HOME directory. + Make sure to set HF_HOME to your desired path before running this function. + """ + snapshot_download("nvidia/Cosmos-1.0-Guardrail") + snapshot_download("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8") + snapshot_download("nvidia/Cosmos-1.0-Diffusion-7B-Text2World", allow_patterns=["nemo/*"]) + snapshot_download("nvidia/Cosmos-1.0-Diffusion-14B-Text2World", allow_patterns=["nemo/*"]) + snapshot_download("nvidia/Cosmos-1.0-Prompt-Upsampler-12B-Text2World") + snapshot_download("google-t5/t5-11b", ignore_patterns=["*.h5"]) + + +def main(): + # Check if HF_HOME is set + hf_home = os.environ.get("HF_HOME") + if not hf_home: + raise EnvironmentError( + "The HF_HOME environment variable is not set. " + "Please set it to your desired path before running this script." + ) + + # Download Cosmos Diffusion NeMo checkpoints + download_diffusion_nemo() + + +if __name__ == "__main__": + main() diff --git a/cosmos1/models/diffusion/nemo/inference/README.md b/cosmos1/models/diffusion/nemo/inference/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fd00b4ff688ed0bdc1658241f0f86a12dbb787a3 --- /dev/null +++ b/cosmos1/models/diffusion/nemo/inference/README.md @@ -0,0 +1,185 @@ +# Cosmos Diffusion-based World Foundation Models: NeMo Framework User Guide + +Learn how to [run inference](#inference) with Cosmos Diffusion-based World Foundation Models (WFMs) using the [NVIDIA NeMo Framework](https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html) for your custom Physical AI tasks by following this guide. + +## Model Support Matrix + +The NeMo Framework supports the following Cosmos Diffusion models. 
Review the available models and their compute requirements for post-training and inference to determine the best model for your use case. + +| Model Name | Model Status | Compute Requirements for Inference | Multi-GPU Support | +|----------------------------------------------|------------------|------------------------------------------|---------| +| Cosmos-1.0-Diffusion-7B-Text2World | **Supported** | 1 NVIDIA GPU* | **Supported** | +| Cosmos-1.0-Diffusion-14B-Text2World | **Supported** | 1 NVIDIA GPU* | **Supported** | +| Cosmos-1.0-Diffusion-7B-Video2World | **Coming Soon** | | | +| Cosmos-1.0-Diffusion-14B-Video2World | **Coming Soon** | | | + + +**\*** `H100-80GB` or `A100-80GB` GPUs are recommended. + +## Post-Trained Model Inference Support Matrix + +Cosmos Diffusion-based WFMs can also be post-trained for a variety of Physical AI tasks and used for inference. Review the following table for a list of available Physical AI post-training tasks: + +| Post-training Task | Inference Support Status | +|-------------------------|--------------------| +| General post-training | **Supported** | +| Instruction control | **Coming Soon** | +| Action control | **Coming Soon** | +| Camera control | **Coming Soon** | +| Multi-view generation | **Coming Soon** | +| Multi-view generation with vehicle trajectory control | **Coming Soon** | + +## Prerequisites + +### 1. Review General Requirements + +- System Configuration + - **NVIDIA GPU and driver**: Ensure you have access to the minimum compute required to run the model(s), as listed in the model support matrix. + - **Containerization Platform**: We recommend using Docker with NVIDIA Container Runtime (alternatively, you may use NVIDIA enroot). +- Get your [Hugging Face User Access Token](https://huggingface.co/docs/hub/en/security-tokens), which is required to obtain the Cosmos models for training and inference. +- Get your [Weights and Biases API Key](https://docs.wandb.ai/support/find_api_key/) for logging and tracking. + +### 2. Clone the Cosmos Repository + +```bash +git clone git@github.com:NVIDIA/Cosmos.git +``` + +### 3. Start the Container + +The [NeMo Framework container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) supports post-training and inference for Cosmos Diffusion models. + +Run the following command to download and start the container: +```bash +docker run --ipc=host -it --gpus=all \ + -v $PATH_TO_COSMOS_REPO:/workspace/Cosmos \ + nvcr.io/nvidia/nemo:cosmos.1.0 bash +``` + +### 4. Download Checkpoints + +To help you get started, we've provided a [download script](../download_diffusion_nemo.py) to get the Cosmos Diffusion checkpoints from Hugging Face. These checkpoints are in the NeMo distributed checkpoint format required to run post-training and inference with NeMo Framework. + +1. Set the following environment variables: + ```bash + # You must set HF_HOME before running this script. + export HF_TOKEN="" + export HF_HOME="" + ``` +2. Run the following command to download the models: + ```bash + cd /workspace/Cosmos + python cosmos1/models/diffusion/nemo/download_diffusion_nemo.py + ``` + +## Run Inference + +Running inference with Cosmos Diffusion models lets you generate a video conditioned on a text prompt. + +Our inference script enables accelerated world generation with context parallel. We use context parallelism to split the diffusion process across multiple GPUs, providing near-linear scaling efficiency.
Our diffusion pipeline also allows the user to set a variety of hyperparameters including the random seed, classifier-free guidance scale, negative prompt, video resolution, and video fps. + +General post-training is essentially a continuation of pre-training. To perform inference with models that have been post-trained with general post-training, you can set the `subject_name` parameter to the subject the model was post-trained on. The `prompt` parameter is then used to describe the setting and events in the generated world. The final prompt will be "A video of sks `{subject_name}`. `{prompt}`". We can also use [inference/general.py](./general.py) to perform inference on the base model since the model architecture is the same as the general post-trained models. + +We also provide the option to upsample the `prompt` and make it more detailed. This can improve the quality of the generated world. + +### Run the Inference Script with Base Model + +Complete the following steps to generate a new output video of a robot cooking. + +1. Set the following environment variables: + ```bash + # HuggingFace Cache to save T5 text encoder, video tokenizer, prompt upsampler, and guardrails weights. + export HF_TOKEN="" + export HF_HOME="" + + # Number of GPU devices available for inference. Supports up to 8 GPUs for accelerated inference. + export NUM_DEVICES=1 + + # Prompt describing world scene and actions taken by subject (if provided). + export PROMPT="The teal robot is cooking food in a kitchen. Steam rises from a simmering pot as the robot chops vegetables on a worn wooden cutting board. Copper pans hang from an overhead rack, catching glints of afternoon light, while a well-loved cast iron skillet sits on the stovetop next to scattered measuring spoons and a half-empty bottle of olive oil." + ``` +2. Run the following command: + ```bash + NVTE_FUSED_ATTN=0 \ + torchrun --nproc_per_node=$NUM_DEVICES cosmos1/models/diffusion/nemo/inference/general.py \ + --model Cosmos-1.0-Diffusion-7B-Text2World \ + --cp_size $NUM_DEVICES \ + --num_devices $NUM_DEVICES \ + --video_save_path "Cosmos-1.0-Diffusion-7B-Text2World.mp4" \ + --guidance 7 \ + --seed 1 \ + --prompt "$PROMPT" \ + --enable_prompt_upsampler + ``` + +### Run the Inference Script with Post-trained Model + +Complete the following steps to generate a new output video from the model post-trained with general post-training. + +1. Set the following environment variables: + ```bash + # HuggingFace Cache to save T5 text encoder, video tokenizer, prompt upsampler, and guardrails weights. + export HF_TOKEN="" + export HF_HOME="" + + # Inference with post-trained model. Find post-trained model under nemo_experiments. Example path: + export NEMO_CHECKPOINT=nemo_experiments/cosmos_diffusion_7b_text2world_finetune/default/2024-12-17_01-28-03/checkpoints/epoch=39-step=199/weights + + # Number of GPU devices available for inference. Supports up to 8 GPUs for accelerated inference. + export NUM_DEVICES=1 + + # Prompt describing world scene and actions taken by subject (if provided). + export PROMPT="The teal robot is cooking food in a kitchen. Steam rises from a simmering pot as the robot chops vegetables on a worn wooden cutting board. Copper pans hang from an overhead rack, catching glints of afternoon light, while a well-loved cast iron skillet sits on the stovetop next to scattered measuring spoons and a half-empty bottle of olive oil." + ``` +2. 
Run the following command: + ```bash + NVTE_FUSED_ATTN=0 \ + torchrun --nproc_per_node=$NUM_DEVICES cosmos1/models/diffusion/nemo/inference/general.py \ + --model Cosmos-1.0-Diffusion-7B-Text2World \ + --nemo_checkpoint "$NEMO_CHECKPOINT" \ + --cp_size $NUM_DEVICES \ + --num_devices $NUM_DEVICES \ + --video_save_path "Cosmos-1.0-Diffusion-7B-Text2World.mp4" \ + --guidance 7 \ + --seed 1 \ + --prompt "$PROMPT" \ + --subject_name "teal robot" \ + --enable_prompt_upsampler + ``` + +#### Example Output + +The following output is an example video generated from the post-trained model using [`general.py`](./general.py): + + + +Generated videos are saved at the location specified by the `--video_save_path` parameter. + +> **Tip**: +> For faster inference, you can remove the `--enable_prompt_upsampler` parameter, but this may degrade the generated result. + +> **Disclaimer**: +> The post-training example in this documentation is a demonstration of general post-training and not a guaranteed recipe for success. Post-training outcomes depend heavily on the quality and diversity of the dataset. To achieve good results, ensure your dataset is clean, well-structured, diverse, and properly labeled. Poorly prepared data can lead to issues like overfitting, bias, or poor performance. Carefully curate your dataset to reflect the desired use case for reliable results. + +### Configuration Options + +The following table details the parameters that can be modified for accelerated inference with NeMo. You can adjust these parameters to optimize performance based on your specific requirements. The model inference hyperparameters listed below have the same functionality as in [Cosmos Diffusion Common Parameters](../../README.md#parameters). + + +| Parameter | Description | Default | +|--------------------------------|---------------------------------------------------------------------------------|---------| +| `--model` | Name of Cosmos Text2World Diffusion model to use for inference. | `Cosmos-1.0-Diffusion-7B-Text2World` | +| `--prompt` | Prompt which the sampled video is conditioned on. Tries to generate what is mentioned in the prompt. | *None* (user must provide) | +| `--negative_prompt` | Negative prompt for improved quality | "The video captures a series of frames showing ugly scenes..." | +| `--subject_name` | Name of the subject the model was post-trained on. This can be left empty for base model inference. | *None* | +| `--guidance` | A control mechanism that determines how closely the model follows specified conditions (prompt) during the generation process. We recommend starting with a guidance of 7 and increasing it later if necessary. | 7 | +| `--sampler` | Sampling method used for generation. Only supports **RES** sampler from [this paper](https://arxiv.org/pdf/2308.02157). | `RES` | +| `--video_save_path` | Location to save generated videos. | `Cosmos-1.0-Diffusion-7B-Text2World.mp4` | +| `--fps` | Frames-per-second of generated video. Cosmos Diffusion models generate videos at 24 FPS by default. | 24 | +| `--height` | Height of the generated video. Set to 704 pixels by default, which is the largest supported video height for Cosmos Diffusion. | 704 | +| `--width` | Width of the generated video. Set to 1280 pixels by default, which is the largest supported video width for Cosmos Diffusion. | 1280 | +| `--seed` | Random seed for generating initial noise sample. Changing this will create a different video for the same prompt.
Keep the seed fixed to maintain deterministic video generations. | 1 | +| `--num_devices` | [1–8] Number of GPUs to use in parallel for inference. | 8 | +| `--cp_size` | [1–8] Number of context parallel ranks to spawn for parallelized inference. Must be equal to `num_devices`. | 8 | diff --git a/cosmos1/models/diffusion/nemo/inference/__init__.py b/cosmos1/models/diffusion/nemo/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b11bcc2d6929f824fade4d22634d2c9a8c7811a7 --- /dev/null +++ b/cosmos1/models/diffusion/nemo/inference/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos1/models/diffusion/nemo/inference/general.py b/cosmos1/models/diffusion/nemo/inference/general.py new file mode 100644 index 0000000000000000000000000000000000000000..8c52e3abfb395dc13ec5ddc22ce562dccabb7a78 --- /dev/null +++ b/cosmos1/models/diffusion/nemo/inference/general.py @@ -0,0 +1,353 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import numpy as np +import torch +from huggingface_hub import snapshot_download +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from nemo import lightning as nl +from nemo.lightning.megatron_parallel import MegatronParallel + +MegatronParallel.init_ddp = lambda self: None +from nemo.collections.diffusion.mcore_parallel_utils import Utils +from nemo.collections.diffusion.sampler.conditioner import VideoConditioner +from nemo.collections.diffusion.sampler.conditioner_configs import ( + FPSConfig, + ImageSizeConfig, + NumFramesConfig, + PaddingMaskConfig, + TextConfig, +) +from nemo.collections.diffusion.sampler.cosmos.cosmos_diffusion_pipeline import CosmosDiffusionPipeline +from transformers import T5EncoderModel, T5TokenizerFast + +from cosmos1.models.diffusion.nemo.inference.inference_utils import process_prompt, save_video +from .log import log + +EXAMPLE_PROMPT = ( + "The teal robot is cooking food in a kitchen. Steam rises from a simmering pot " + "as the robot chops vegetables on a worn wooden cutting board. 
Copper pans hang " + "from an overhead rack, catching glints of afternoon light, while a well-loved " + "cast iron skillet sits on the stovetop next to scattered measuring spoons and " + "a half-empty bottle of olive oil." +) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Video foundation model inference") + parser.add_argument( + "--model", + type=str, + default="Cosmos-1.0-Diffusion-7B-Text2World", + choices=["Cosmos-1.0-Diffusion-7B-Text2World", "Cosmos-1.0-Diffusion-14B-Text2World"], + ) + parser.add_argument( + "--prompt", + type=str, + default=EXAMPLE_PROMPT, + help="Prompt which the sampled video condition on", + ) + # We turn on negative prompt by default. set to "" to turn it off. + parser.add_argument( + "--negative_prompt", + type=str, + default=( + "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, " + "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + "Overall, the video is of poor quality." + ), + help="Negative prompt which the sampled video condition on", + ) + parser.add_argument("--subject_name", type=str, default="", help="Name of fine-tuned subject") + parser.add_argument("--guidance", type=float, default=7, help="Classifier-free guidance scale") + parser.add_argument("--sampler", type=str, default="RES", help="Currently only supports RES sampler.") + parser.add_argument("--video_save_path", type=str, default="outputs", help="Path to save the video") + parser.add_argument("--fps", type=int, default=24, help="FPS of the sampled video") + parser.add_argument("--height", type=int, default=704, help="Height of image to sample") + parser.add_argument("--width", type=int, default=1280, help="Width of image to sample") + parser.add_argument("--seed", type=int, default=1, help="Random seed") + parser.add_argument("--num_devices", type=int, default=1, help="Number of devices for inference") + parser.add_argument("--cp_size", type=int, default=1, help="Number of cp ranks for multi-gpu inference.") + parser.add_argument("--num_steps", type=float, default=35, help="Number of diffusion sampling steps") + parser.add_argument("--num_video_frames", type=int, default=121, help="Number of video frames to sample") + parser.add_argument("--tokenizer_dir", type=str, default="", help="Directory for video tokenizer") + parser.add_argument("--cosmos_assets_dir", type=str, default="", help="Directory containing cosmos assets") + parser.add_argument("--prompt_upsampler_dir", type=str, default="", help="Prompt upsampler weights directory") + parser.add_argument("--guardrail_dir", type=str, default="", help="Guardrails weights directory") + parser.add_argument("--nemo_checkpoint", type=str, default="", help="Video diffusion model nemo weights") + parser.add_argument("--t5_cache_dir", type=str, default=None, help="Path to T5 model") + parser.add_argument( + "--enable_prompt_upsampler", action="store_true", help="Whether to use prompt upsampling before generation" + ) + + args = parser.parse_args() + return args + + +def print_rank_0(string: str): + rank = torch.distributed.get_rank() + if rank == 0: + log.info(string) + + +@torch.no_grad() +def 
encode_for_batch(tokenizer: T5TokenizerFast, encoder: T5EncoderModel, prompts: list[str], max_length: int = 512): + """ + Encode a batch of text prompts to a batch of T5 embeddings. + Parameters: + tokenizer: T5 embedding tokenizer. + encoder: T5 embedding text encoder. + prompts: A batch of text prompts. + max_length: Sequence length of text embedding (defaults to 512). + """ + + batch_encoding = tokenizer.batch_encode_plus( + prompts, + return_tensors="pt", + truncation=True, + padding="max_length", + max_length=max_length, + return_length=True, + return_offsets_mapping=False, + ) + + # We expect all the processing is done on GPU. + input_ids = batch_encoding.input_ids.cuda() + attn_mask = batch_encoding.attention_mask.cuda() + + outputs = encoder(input_ids=input_ids, attention_mask=attn_mask) + encoded_text = outputs.last_hidden_state + + lengths = attn_mask.sum(dim=1).cpu() + for batch_id in range(encoded_text.shape[0]): + encoded_text[batch_id][lengths[batch_id] :] = 0 + + return encoded_text + + +def init_video_tokenizer(args): + """ + Initializes video tokenizer based on specified video tokenizer config / path. + """ + from nemo.collections.diffusion.models.model import DiT7BConfig, DiT14BConfig + + vae_path = os.path.join(args.cosmos_assets_dir, args.tokenizer_dir) + if "7b" in args.nemo_checkpoint.lower(): + dit_config = DiT7BConfig(vae_path=vae_path) + if "14b" in args.nemo_checkpoint.lower(): + dit_config = DiT14BConfig(vae_path=vae_path) + vae = dit_config.configure_vae() + return vae + + +def check_prompt(args): + prompt = args.prompt + subject_string = None + if args.subject_name: + subject_string = f"A video of sks {args.subject_name}" + + prompt = process_prompt( + prompt=prompt, + checkpoint_dir=args.cosmos_assets_dir, + prompt_upsampler_dir=args.prompt_upsampler_dir, + guardrails_dir=args.guardrail_dir, + enable_prompt_upsampler=args.enable_prompt_upsampler, + ) + + if subject_string: + prompt = f"{subject_string}. 
{prompt}" + return prompt + + +def prepare_data_batch(args, vae, t5_embeding_max_length=512): + tokenizer = T5TokenizerFast.from_pretrained("google-t5/t5-11b", cache_dir=args.t5_cache_dir) + text_encoder = T5EncoderModel.from_pretrained("google-t5/t5-11b", cache_dir=args.t5_cache_dir) + text_encoder.to("cuda") + text_encoder.eval() + + # Encode text to T5 embedding + out = encode_for_batch(tokenizer, text_encoder, [args.prompt])[0] + encoded_text = torch.tensor(out, dtype=torch.bfloat16) + + # Padding T5 embedding to t5_embeding_max_length + L, C = encoded_text.shape + t5_embed = torch.zeros(1, t5_embeding_max_length, C, dtype=torch.bfloat16) + t5_embed[0, :L] = encoded_text + + if args.negative_prompt: + out = encode_for_batch(tokenizer, text_encoder, [args.negative_prompt])[0] + + encoded_text = torch.tensor(out, dtype=torch.bfloat16) + # Padding T5 embedding to t5_embeding_max_length + L, C = encoded_text.shape + neg_t5_embed = torch.zeros(1, t5_embeding_max_length, C, dtype=torch.bfloat16) + neg_t5_embed[0, :L] = encoded_text + else: + neg_t5_embed = None + + # Prepare data sample + t, h, w = args.num_video_frames, args.height, args.width + state_shape = [ + vae.channel, + vae.get_latent_num_frames(t), + h // vae.spatial_compression_factor, + w // vae.spatial_compression_factor, + ] + + data_batch = { + "video": torch.zeros((1, 3, t, h, w), dtype=torch.uint8).cuda(), + "t5_text_embeddings": t5_embed, + "t5_text_mask": torch.ones(1, t5_embeding_max_length, dtype=torch.bfloat16).cuda(), + # other conditions + "image_size": torch.tensor( + [[args.height, args.width, args.height, args.width]] * 1, dtype=torch.bfloat16 + ).cuda(), + "fps": torch.tensor([args.fps] * 1, dtype=torch.bfloat16).cuda(), + "num_frames": torch.tensor([args.num_video_frames] * 1, dtype=torch.bfloat16).cuda(), + "padding_mask": torch.zeros((1, 1, args.height, args.width), dtype=torch.bfloat16).cuda(), + } + if args.negative_prompt: + data_batch["neg_t5_text_embeddings"] = neg_t5_embed + data_batch["neg_t5_text_mask"] = torch.ones(1, t5_embeding_max_length, dtype=torch.bfloat16) + + return data_batch, state_shape + + +def setup_diffusion_pipeline(args): + """ + Initialize DiT model, parallel strategy, and diffusion pipeline for inference. + """ + # Initialize DiT model + from nemo.collections.diffusion.models.model import DiT7BConfig, DiT14BConfig, DiTModel + + if "7b" in args.nemo_checkpoint.lower(): + dit_config = DiT7BConfig() + if "14b" in args.nemo_checkpoint.lower(): + dit_config = DiT14BConfig() + + dit_model = DiTModel(dit_config) + + # Initialize model parallel strategy. Here, we only use context parallel. 
+ strategy = nl.MegatronStrategy( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=args.cp_size, + pipeline_dtype=torch.bfloat16, + ) + + # Initialize ptl trainer + trainer = nl.Trainer( + devices=args.num_devices, # you can change the numebr of devices to suit your setup + max_steps=1, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + ) + + # Convert trainer to fabric for inference + fabric = trainer.to_fabric() + fabric.strategy.checkpoint_io.save_ckpt_format = "zarr" + fabric.strategy.checkpoint_io.validate_access_integrity = False + model = fabric.load_model(args.nemo_checkpoint, dit_model).to(device="cuda", dtype=torch.bfloat16) + + # Set up diffusion pipeline + conditioner = VideoConditioner( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), + ) + diffusion_pipeline = CosmosDiffusionPipeline( + net=model.module, conditioner=conditioner, sampler_type=args.sampler, seed=args.seed + ) + + return diffusion_pipeline + + +def run_diffusion_inference(args, data_batch, state_shape, vae, diffusion_pipeline): + # prepare data + data_batch = {k: v.cuda() if torch.is_tensor(v) else v for k, v in data_batch.items()} + data_batch["inference_fwd"] = True + sample = diffusion_pipeline.generate_samples_from_batch( + data_batch, + guidance=args.guidance, + state_shape=state_shape, + num_steps=args.num_steps, + is_negative_prompt=True if "neg_t5_text_embeddings" in data_batch else False, + ) + + rank = torch.distributed.get_rank() + if rank == 0: + # Post-processing and save video + sigma_data = 0.5 + grid = (1.0 + vae.decode(sample / sigma_data)).clamp(0, 2) / 2 + grid = (grid[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy().astype(np.uint8) + save_video( + grid=grid, + fps=args.fps, + H=args.height, + W=args.width, + video_save_quality=5, + video_save_path=args.video_save_path, + checkpoint_dir=args.cosmos_assets_dir, + guardrails_dir=args.guardrail_dir, + ) + print_rank_0(f"saved video to {args.video_save_path}!") + + +def main(args): + if args.guardrail_dir == "": + args.guardrail_dir = snapshot_download("nvidia/Cosmos-1.0-Guardrail") + if args.tokenizer_dir == "": + args.tokenizer_dir = snapshot_download("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8") + if args.prompt_upsampler_dir == "" and args.enable_prompt_upsampler: + args.prompt_upsampler_dir = snapshot_download("nvidia/Cosmos-1.0-Prompt-Upsampler-12B-Text2World") + if args.nemo_checkpoint == "": + args.nemo_checkpoint = snapshot_download(f"nvidia/{args.model}", allow_patterns=["nemo/*"]) + args.nemo_checkpoint = os.path.join(args.nemo_checkpoint, "nemo") + + # Initialize megatron model parallel environment + Utils.initialize_distributed(1, 1, context_parallel_size=args.cp_size) + model_parallel_cuda_manual_seed(args.seed) + + args.prompt = check_prompt(args) + + # Load video tokenizer + print_rank_0("initializing video tokenizer...") + vae = init_video_tokenizer(args) + + # Prepare data batch + print_rank_0("preparing data batch...") + data_batch, state_shape = prepare_data_batch(args, vae) + + # Setup model / diffusion pipeline + print_rank_0("setting up diffusion pipeline...") + diffusion_pipeline = setup_diffusion_pipeline(args) + + # Generate video from prompt + print_rank_0("generating video...") + run_diffusion_inference(args, data_batch, state_shape, vae, diffusion_pipeline) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff 
--git a/cosmos1/models/diffusion/nemo/inference/inference_utils.py b/cosmos1/models/diffusion/nemo/inference/inference_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..adf7b2e9b0d0267cd98df33c069d187af932b76d --- /dev/null +++ b/cosmos1/models/diffusion/nemo/inference/inference_utils.py @@ -0,0 +1,156 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import imageio +import numpy as np +import torch + +from ar_model import AutoRegressiveModel +from text2world_prompt_upsampler_inference import ( + create_prompt_upsampler, + run_chat_completion, +) +from presets import ( + create_text_guardrail_runner, + create_video_guardrail_runner, + run_text_guardrail, + run_video_guardrail, +) +from .log import log + + +def get_upsampled_prompt( + prompt_upsampler_model: AutoRegressiveModel, input_prompt: str, temperature: float = 0.01 +) -> str: + """ + Get upsampled prompt from the prompt upsampler model instance. + + Args: + prompt_upsampler_model: The prompt upsampler model instance. + input_prompt (str): Original prompt to upsample. + temperature (float): Temperature for generation (default: 0.01). + + Returns: + str: The upsampled prompt. + """ + dialogs = [ + [ + { + "role": "user", + "content": f"Upsample the short caption to a long caption: {input_prompt}", + } + ] + ] + + upsampled_prompt = run_chat_completion(prompt_upsampler_model, dialogs, temperature=temperature) + return upsampled_prompt + + +def print_rank_0(string: str): + rank = torch.distributed.get_rank() + if rank == 0: + log.info(string) + + +def process_prompt( + prompt: str, + checkpoint_dir: str, + prompt_upsampler_dir: str, + guardrails_dir: str, + image_path: str = None, + enable_prompt_upsampler: bool = True, +) -> str: + """ + Handle prompt upsampling if enabled, then run guardrails to ensure safety. + + Args: + prompt (str): The original text prompt. + checkpoint_dir (str): Base checkpoint directory. + prompt_upsampler_dir (str): Directory containing prompt upsampler weights. + guardrails_dir (str): Directory containing guardrails weights. + image_path (str, optional): Path to an image, if any (not implemented for upsampling). + enable_prompt_upsampler (bool): Whether to enable prompt upsampling. + + Returns: + str: The upsampled prompt or original prompt if upsampling is disabled or fails. 
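+
+    Raises:
+        ValueError: If the text guardrail blocks the original prompt or the upsampled prompt.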
+ """ + + text_guardrail = create_text_guardrail_runner(os.path.join(checkpoint_dir, guardrails_dir)) + + # Check if the prompt is safe + is_safe = run_text_guardrail(str(prompt), text_guardrail) + if not is_safe: + raise ValueError("Guardrail blocked world generation.") + + if enable_prompt_upsampler: + if image_path: + raise NotImplementedError("Prompt upsampling is not supported for image generation") + else: + prompt_upsampler = create_prompt_upsampler( + checkpoint_dir=os.path.join(checkpoint_dir, prompt_upsampler_dir) + ) + upsampled_prompt = get_upsampled_prompt(prompt_upsampler, prompt) + print_rank_0(f"Original prompt: {prompt}\nUpsampled prompt: {upsampled_prompt}\n") + del prompt_upsampler + + # Re-check the upsampled prompt + is_safe = run_text_guardrail(str(upsampled_prompt), text_guardrail) + if not is_safe: + raise ValueError("Guardrail blocked world generation.") + + return upsampled_prompt + else: + return prompt + + +def save_video( + grid: np.ndarray, + fps: int, + H: int, + W: int, + video_save_quality: int, + video_save_path: str, + checkpoint_dir: str, + guardrails_dir: str, +): + """ + Save video frames to file, applying a safety check before writing. + + Args: + grid (np.ndarray): Video frames array [T, H, W, C]. + fps (int): Frames per second. + H (int): Frame height. + W (int): Frame width. + video_save_quality (int): Video encoding quality (0-10). + video_save_path (str): Output video file path. + checkpoint_dir (str): Directory containing model checkpoints. + guardrails_dir (str): Directory containing guardrails weights. + """ + video_classifier_guardrail = create_video_guardrail_runner(os.path.join(checkpoint_dir, guardrails_dir)) + + # Safety check on the entire video + grid = run_video_guardrail(grid, video_classifier_guardrail) + + kwargs = { + "fps": fps, + "quality": video_save_quality, + "macro_block_size": 1, + "ffmpeg_params": ["-s", f"{W}x{H}"], + "output_params": ["-f", "mp4"], + } + + imageio.mimsave(video_save_path, grid, "mp4", **kwargs) diff --git a/cosmos1/models/diffusion/nemo/post_training/README.md b/cosmos1/models/diffusion/nemo/post_training/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ba8c3dedbe03a072ed517bb06ceaf44138f36ac4 --- /dev/null +++ b/cosmos1/models/diffusion/nemo/post_training/README.md @@ -0,0 +1,181 @@ +# Cosmos Diffusion-based World Foundation Models: NeMo Framework User Guide + +Learn how to [post-train](#post-train) Cosmos Diffusion-based World Foundation Models (WFMs) using the [NVIDIA NeMo Framework](https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html) for your custom Physical AI tasks by following this guide. + +## Model Support Matrix + +The NeMo Framework supports the following Cosmos Diffusion models. Review the available models and their compute requirements for post-tuning and inference to determine the best model for your use case. + +| Model Name | Model Status | Compute Requirements for Post-Training | +|----------------------------------------------|------------------|------------------------------------------| +| Cosmos-1.0-Diffusion-7B-Text2World | **Supported** | 8 NVIDIA GPUs* | +| Cosmos-1.0-Diffusion-14B-Text2World | **Supported** | 8 NVIDIA GPUs* | +| Cosmos-1.0-Diffusion-7B-Video2World | **Coming Soon** | | +| Cosmos-1.0-Diffusion-14B-Video2WorldB | **Coming Soon** | | + + +**\*** `H100-80GB` or `A100-80GB` GPUs are recommended. + +## Post-Training Support Matrix + +Cosmos Diffusion-based WFMs can be post-trained for a variety of Physical AI tasks. 
Review the following table for a list of available Physical AI post-training tasks: + +| Post-training Task | Post-Training Support Status | +|-------------------------|--------------------| +| General post-training | **Supported** | +| Instruction control | **Coming Soon** | +| Action control | **Coming Soon** | +| Camera control | **Coming Soon** | +| Multi-view generation | **Coming Soon** | +| Multi-view generation with vehicle trajectory control | **Coming Soon** | + +## Prerequisites + +### 1. Review General Requirements + +- System Configuration + - **NVIDIA GPU and driver**: Ensure you have access to the minimum compute required to run the model(s), as listed in the model support matrix. + - **Containerization Platform**: We recommend using Docker with NVIDIA Container Runtime (alternatively, you may use NVIDIA enroot). +- Get your [Hugging Face User Access Token](https://huggingface.co/docs/hub/en/security-tokens), which is required to obtain the Cosmos models for training and inference. +- Get your [Weights and Biases API Key](https://docs.wandb.ai/support/find_api_key/) for logging and tracking. + +### 2. Clone the Cosmos Repository + +```bash +git clone git@github.com:NVIDIA/Cosmos.git +``` + +### 3. Start the Container + +The [NeMo Framework container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) supports post-training and inference for Cosmos Diffusion models. + +Run the following command to download and start the container: +```bash +docker run --ipc=host -it --gpus=all \ + -v $PATH_TO_COSMOS_REPO:/workspace/Cosmos \ + nvcr.io/nvidia/nemo:cosmos.1.0 bash +``` + +### 4. Download Checkpoints + +To help you get started, we've provided a [download script](../download_diffusion_nemo.py) to get the Cosmos Diffusion checkpoints from Hugging Face. These checkpoints are in the NeMo distributed checkpoint format required to run post-training and inference with NeMo Framework. + +1. Set the following environment variables: + ```bash + # You must set HF_HOME before running this script. + export HF_TOKEN="" + export HF_HOME="" + ``` +2. Run the following command to download the models: + ```bash + cd /workspace/Cosmos + python cosmos1/models/diffusion/nemo/download_diffusion_nemo.py + ``` + +## Post-train + +Post-training a Cosmos Diffusion-based WFM enables you to train the model to generate videos that are more specific to your Physical AI use case. + +For example, if you want to generate action sequences for a specific robot, you can post-train the model to generate videos that are more aligned with typical actions/outcomes for that robot. + +There are 3 steps to post-training: preparing a dataset, preprocessing the data, and post-training the model. + +### 1. Prepare a Dataset + +The first step is to prepare a dataset. Post-training a Cosmos-1.0-Diffusion-Text2World-{7B/14B}-NeMo model enables you to generate videos of a specific subject in new environments using a collection of input videos of that same subject as reference material. + +You must provide a folder containing a collection of videos in **MP4 format**, preferably 720p. These videos should focus on the subject throughout the entire video so that each video chunk contains the subject. + +Run the following command to download the sample videos used for post-training: + +```bash +huggingface-cli download nvidia/Cosmos-NeMo-Assets --repo-type dataset --local-dir cosmos1/models/diffusion/assets/ --include "*.mp4*" +``` + +### 2. Preprocess Data + +The second step is to preprocess the input videos. 
This generates the post-training samples and the metadata required for the post-training process by: + +1. Selecting `N` chunks of 121 frames from each video, generating `N` post-training samples per video. +2. Encoding the 121 frames by first independently compressing the first frame and then applying an 8x temporal compression for the rest of the frames. +3. Generating `total_samples = # of videos x # of chunks` post-training samples. + +Before proceeding, ensure all videos are in **RGB format**. Complete the following steps to generate the post-training samples and metadata for the robot dataset. Remember to follow the given prompt format by including the subject's name in the prompt. For example, if the subject is "robot," the prompt should read `"A video of sks robot."`. + +1. Set the following environment variables: + ```bash + export HF_TOKEN="" + export HF_HOME="" + + # Path to Raw mp4 videos. + export RAW_DATA="cosmos1/models/diffusion/assets/nemo_diffusion_example_data" + + # Path to Processed Dataset. + export CACHED_DATA="./cached_data" && mkdir -p $CACHED_DATA + ``` +2. Run the following command to preprocess the data: + ```bash + python cosmos1/models/diffusion/nemo/post_training/prepare_dataset.py \ + --dataset_path $RAW_DATA \ + --output_path $CACHED_DATA \ + --prompt "A video of sks teal robot." \ + --num_chunks 500 + ``` + +Executing the [data preprocessing script](./prepare_dataset.py) generates the following files for each video (using `[i]` as the `index` of the video) at `$CACHED_DATA` path: + +- **`[i].info.json`**: Metadata for the video sample. +- **`[i].t5_text_embeddings.pth`**: T5-generated text embedding for the video clip. +- **`[i].t5_text_mask.pth`**: Mask for T5 text embedding, set to all ones by default to use the entire text embedding. +- **`[i].video_latent.pth`**: 3D spatiotemporal video tokens generated from the video tokenizer. + +### 3. Post-train the Model + +The third step is to post-train the model. This step uses NeMo Framework's data and model parallelism capabilities to train the model on the post-training samples. This is accomplished by using utilizing Fully Sharded Data Parallel (FSDP) and Tensor Parallelism. + +- **FSDP**: Distributes model parameters, optimizer states, and activations across all GPUs +- **Tensor Parallelism**: Spreads the parameter tensor of individual layers across GPUs. + +> **NOTE**: +> For the 14B model, we also employ activation checkpointing to facilitate single-node training. + +#### Run the Post-training Script + +Complete the following steps to post-train the Cosmos-1.0-Diffusion-7B-Text2World model on the robot dataset using 8 GPUs. + +1. Set the following environment variables: + ```bash + export HF_TOKEN="" + export HF_HOME="" + + # Optionally, you can monitor training progress with Weights and Biases (wandb). + export WANDB_API_KEY="" + export WANDB_PROJECT_NAME="cosmos-diffusion-nemo-post-training" + export WANDB_RUN_ID="cosmos_diffusion_7b_text2world_finetune" + ``` +2. Run the following command for Cosmos-Diffusion-Text2World-7B general post-training: + ```bash + NVTE_FUSED_ATTN=0 \ + CUDA_DEVICE_MAX_CONNECTIONS=1 \ + PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ + torchrun --nproc_per_node=8 cosmos1/models/diffusion/nemo/post_training/general.py \ + --yes \ + --factory cosmos_diffusion_7b_text2world_finetune \ + data.path=$CACHED_DATA \ + trainer.max_steps=1000 \ + optim.config.lr=1e-6 + ``` +3. 
You can now run inference with your post-trained model using the instructions [here](../inference/README.md#run-the-inference-script-with-post-trained-model). + +#### Configuration Options + +Before getting started, review the following parameters made available to the script. You can adjust these parameters to optimize performance based on your specific requirements. + +| Parameter | Description | Default | +|--------------------------------|---------------------------------------------------------------------------------|---------| +| `--factory` | recipe to use cosmos_diffusion_7b_text2world_finetune or cosmos_diffusion_14b_text2world_finetune for general post-training | cosmos_diffusion_7b_text2world_finetune | +| `data.path` | Path to processed post-training dataset (str). | None | +| `resume.restore_config.path` | Path to pre-trained Cosmos Diffusion NeMo distributed checkpoint (str). | None | +| `optim.config.lr` | Learning rate (float). | 1e-6 | +| `trainer.max_steps` | Max number of post-training steps (int). | 1000 | +| `log.log_dir` | Path to folder to save post-training logs and checkpoints (str). | None | diff --git a/cosmos1/models/diffusion/nemo/post_training/general.py b/cosmos1/models/diffusion/nemo/post_training/general.py new file mode 100644 index 0000000000000000000000000000000000000000..13e97ec1b7754cb5e5188a1681c640812d7eef0c --- /dev/null +++ b/cosmos1/models/diffusion/nemo/post_training/general.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
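+
+# NeMo-Run recipe factories for general post-training (fine-tuning) of the
+# Cosmos-1.0-Diffusion-{7B,14B}-Text2World models. Each factory extends the diffusion
+# `pretrain()` recipe: it enables tensor parallelism (size 8) with sequence parallelism
+# and Megatron FSDP, points the video-folder datamodule at a directory of preprocessed
+# samples, and restores the pretrained NeMo checkpoint downloaded from Hugging Face
+# before training starts.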
+ +import os + +import nemo_run as run +from huggingface_hub import snapshot_download +from nemo.collections import llm +from nemo.collections.diffusion.models.model import DiT7BConfig, DiT14BConfig +from nemo.collections.diffusion.train import pretrain, videofolder_datamodule +from nemo.lightning.pytorch.strategies.utils import RestoreConfig + + +@run.cli.factory(target=llm.train) +def cosmos_diffusion_7b_text2world_finetune() -> run.Partial: + # Model setup + recipe = pretrain() + recipe.model.config = run.Config(DiT7BConfig) + + # Trainer setup + recipe.trainer.max_steps = 1000 + recipe.optim.config.lr = 1e-6 + + # Tensor / Sequence parallelism + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.trainer.strategy.sequence_parallel = True + recipe.trainer.strategy.ckpt_async_save = False + + # FSDP + recipe.trainer.strategy.ddp.with_megatron_fsdp_code_path = True + recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = "MODEL_AND_OPTIMIZER_STATES" + recipe.trainer.strategy.ddp.overlap_param_gather = True + recipe.trainer.strategy.ddp.overlap_grad_reduce = True + recipe.model.config.use_cpu_initialization = True + + # Data setup + recipe.data = videofolder_datamodule() + recipe.data.path = "" # path to folder with processed dataset + + # Checkpoint load + recipe.resume.restore_config = run.Config(RestoreConfig, load_artifacts=False) + recipe.resume.restore_config.path = os.path.join( + snapshot_download("nvidia/Cosmos-1.0-Diffusion-7B-Text2World", allow_patterns=["nemo/*"]), "nemo" + ) # path to diffusion model checkpoint + recipe.resume.resume_if_exists = False + + # Directory to save checkpoints / logs + recipe.log.log_dir = "nemo_experiments/cosmos_diffusion_7b_text2world_finetune" + + return recipe + + +@run.cli.factory(target=llm.train) +def cosmos_diffusion_14b_text2world_finetune() -> run.Partial: + # Model setup + recipe = pretrain() + recipe.model.config = run.Config(DiT14BConfig) + + # Trainer setup + recipe.trainer.max_steps = 1000 + recipe.optim.config.lr = 1e-6 + + # Tensor / Sequence parallelism + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.trainer.strategy.sequence_parallel = True + recipe.trainer.strategy.ckpt_async_save = False + + # FSDP + recipe.trainer.strategy.ddp.with_megatron_fsdp_code_path = True + recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = "MODEL_AND_OPTIMIZER_STATES" + recipe.trainer.strategy.ddp.overlap_param_gather = True + recipe.trainer.strategy.ddp.overlap_grad_reduce = True + recipe.model.config.use_cpu_initialization = True + + # Activation Checkpointing + recipe.model.config.recompute_granularity = "full" + recipe.model.config.recompute_method = "uniform" + recipe.model.config.recompute_num_layers = 1 + + # Data setup + recipe.data = videofolder_datamodule() + recipe.data.path = "" # path to folder with processed dataset + + # Checkpoint load + recipe.resume.restore_config = run.Config(RestoreConfig, load_artifacts=False) + recipe.resume.restore_config.path = os.path.join( + snapshot_download("nvidia/Cosmos-1.0-Diffusion-14B-Text2World", allow_patterns=["nemo/*"]), "nemo" + ) # path to diffusion model checkpoint + + recipe.resume.resume_if_exists = False + + # Directory to save checkpoints / logs + recipe.log.log_dir = "nemo_experiments/cosmos_diffusion_14b_text2world_finetune" + + return recipe + + +if __name__ == "__main__": + run.cli.main(llm.train, default_factory=cosmos_diffusion_7b_text2world_finetune) diff --git a/cosmos1/models/diffusion/nemo/post_training/prepare_dataset.py 
b/cosmos1/models/diffusion/nemo/post_training/prepare_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..eeb08cde550949b08444bd5191d87b52896c82e7 --- /dev/null +++ b/cosmos1/models/diffusion/nemo/post_training/prepare_dataset.py @@ -0,0 +1,188 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import glob +import json +import os +import random + +import torch +import torchvision +from einops import rearrange +from huggingface_hub import snapshot_download +from nemo.collections.diffusion.models.model import DiT7BConfig +from tqdm import tqdm +from transformers import T5EncoderModel, T5TokenizerFast + +from .log import log + + +def get_parser(): + parser = argparse.ArgumentParser(description="Process some configurations.") + parser.add_argument("--tokenizer_dir", type=str, default="", help="Path to the VAE model") + parser.add_argument( + "--dataset_path", type=str, default="video_dataset", help="Path to the dataset (a folder of videos)" + ) + parser.add_argument("--output_path", type=str, default="video_dataset_cached", help="Path to the output directory") + parser.add_argument("--prompt", type=str, default="a video of sks.", help="Prompt for the video") + parser.add_argument("--num_chunks", type=int, default=5, help="Number of random chunks to sample per video") + parser.add_argument("--height", type=int, default=704, help="Height to resize video") + parser.add_argument("--width", type=int, default=1280, help="Width to resize video") + return parser + + +def init_t5(): + """Initialize and return the T5 tokenizer and text encoder.""" + tokenizer = T5TokenizerFast.from_pretrained("google-t5/t5-11b") + text_encoder = T5EncoderModel.from_pretrained("google-t5/t5-11b") + text_encoder.to("cuda") + text_encoder.eval() + return tokenizer, text_encoder + + +def init_video_tokenizer(tokenizer_dir: str): + """Initialize and return the Cosmos Video tokenizer.""" + dit_config = DiT7BConfig(vae_path=tokenizer_dir) + vae = dit_config.configure_vae() + return vae + + +@torch.no_grad() +def encode_for_batch(tokenizer, encoder, prompts: list[str], max_length=512): + """ + Encode a batch of text prompts to a batch of T5 embeddings. + Parameters: + tokenizer: T5 embedding tokenizer. + encoder: T5 embedding text encoder. + prompts: A batch of text prompts. + max_length: Sequence length of text embedding (defaults to 512). + """ + + batch_encoding = tokenizer.batch_encode_plus( + prompts, + return_tensors="pt", + truncation=True, + padding="max_length", + max_length=max_length, + return_length=True, + return_offsets_mapping=False, + ) + + # We expect all the processing is done on GPU. 
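+    # Move token ids and the attention mask to the GPU; embedding positions past each
+    # prompt's true token length are zeroed out after encoding below.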
+ input_ids = batch_encoding.input_ids.cuda() + attn_mask = batch_encoding.attention_mask.cuda() + + outputs = encoder(input_ids=input_ids, attention_mask=attn_mask) + encoded_text = outputs.last_hidden_state + + lengths = attn_mask.sum(dim=1).cpu() + for batch_id in range(encoded_text.shape[0]): + encoded_text[batch_id][lengths[batch_id] :] = 0 + + return encoded_text + + +def main(args): + # Set up output directory + os.makedirs(args.output_path, exist_ok=True) + + # Initialize T5 + tokenizer, text_encoder = init_t5() + + # Initialize the VAE + if args.tokenizer_dir == "": + args.tokenizer_dir = snapshot_download("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8") + vae = init_video_tokenizer(args.tokenizer_dir) + + # Constants + t5_embeding_max_length = 512 + chunk_duration = vae.video_vae.pixel_chunk_duration # Frames per chunk + cnt = 0 # File index + + # Check if dataset_path is correct + files = glob.glob(os.path.join(args.dataset_path, "*.mp4")) + if not files: + raise ValueError(f"Dataset path {args.dataset_path} does not contain any .mp4 files.") + + # Process each video in the dataset folder + with torch.no_grad(): + for video_path in tqdm(glob.glob(os.path.join(args.dataset_path, "*.mp4"))): + # Read video (T x H x W x C) + video, _, meta = torchvision.io.read_video(video_path) + T, H, W, C = video.shape + + # Skip videos shorter than one chunk + if T < chunk_duration: + log.info(f"Video {video_path} is shorter than {chunk_duration} frames. Skipped.") + continue + + # Sample random segments + for _ in range(args.num_chunks): + start_idx = random.randint(0, T - chunk_duration) + chunk = video[start_idx : start_idx + chunk_duration] # (chunk_duration, H, W, C) + + # Rearrange dimensions: (T, H, W, C) -> (T, C, H, W) + chunk = rearrange(chunk, "t h w c -> t c h w") + + # Resize to [704, 1280] for each frame + chunk = torchvision.transforms.functional.resize(chunk, [args.height, args.width]) + + # Expand dims: (T, C, H, W) -> (B=1, C, T, H, W) + chunk = rearrange(chunk, "(b t) c h w -> b c t h w", b=1) + + # Convert to bf16 and normalize from [0, 255] to [-1, 1] + chunk = chunk.to(device="cuda", dtype=torch.bfloat16, non_blocking=True) / 127.5 - 1.0 + + # Encode video + latent = vae.encode(chunk).cpu() # shape: (1, latent_channels, T//factor, H//factor, W//factor) + + # Encode text + out = encode_for_batch(tokenizer, text_encoder, [args.prompt])[0] + encoded_text = torch.tensor(out, dtype=torch.bfloat16) + + # Pad T5 embedding to t5_embeding_max_length + L, C_ = encoded_text.shape + t5_embed = torch.zeros(1, t5_embeding_max_length, C_, dtype=torch.bfloat16) + t5_embed[0, :L] = encoded_text + + # Save data to folder + torch.save(latent[0], os.path.join(args.output_path, f"{cnt}.video_latent.pth")) + torch.save(t5_embed[0], os.path.join(args.output_path, f"{cnt}.t5_text_embeddings.pth")) + + # Create a T5 text mask of all ones + torch.save( + torch.ones(512, dtype=torch.bfloat16), os.path.join(args.output_path, f"{cnt}.t5_text_mask.pth") + ) + + # Save metadata + info = { + "height": H, + "width": W, + "fps": meta["video_fps"], + "num_frames": chunk_duration, + "video_path": os.path.basename(video_path), + "start_frame": start_idx, + } + with open(os.path.join(args.output_path, f"{cnt}.info.json"), "w") as json_file: + json.dump(info, json_file) + + cnt += 1 + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + main(args) diff --git a/cosmos1/models/diffusion/networks/__init__.py b/cosmos1/models/diffusion/networks/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos1/models/guardrail/README.md b/cosmos1/models/guardrail/README.md new file mode 100644 index 0000000000000000000000000000000000000000..df170b5b9334eb545864b036444dc0fcca760868 --- /dev/null +++ b/cosmos1/models/guardrail/README.md @@ -0,0 +1,21 @@ +# Cosmos Guardrail + +This page outlines a set of tools to ensure content safety in Cosmos. For implementation details, please consult the [Cosmos paper](https://research.nvidia.com/publication/2025-01_cosmos-world-foundation-model-platform-physical-ai). + +## Overview + +Our guardrail system consists of two stages: pre-Guard and post-Guard. + +Cosmos pre-Guard models are applied to text input, including input prompts and upsampled prompts. + +* Blocklist: a keyword list checker for detecting harmful keywords +* Aegis: an LLM-based approach for blocking harmful prompts + +Cosmos post-Guard models are applied to video frames generated by Cosmos models. + +* Video Content Safety Filter: a classifier trained to distinguish between safe and unsafe video frames +* Face Blur Filter: a face detection and blurring module + +## Usage + +Cosmos Guardrail models are integrated into the diffusion and autoregressive world generation pipelines in this repo. Check out the [Cosmos Diffusion Documentation](../diffusion/README.md) and [Cosmos Autoregressive Documentation](../autoregressive/README.md) to download the Cosmos Guardrail checkpoints and run the end-to-end demo scripts with our Guardrail models. diff --git a/cosmos1/models/guardrail/__init__.py b/cosmos1/models/guardrail/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos1/models/guardrail/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos1/models/guardrail/aegis/__init__.py b/cosmos1/models/guardrail/aegis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos1/models/guardrail/aegis/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
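Note on the guardrail flow described in `cosmos1/models/guardrail/README.md` above: the sketch below shows how the two stages fit together using the runner helpers that `cosmos1/models/diffusion/nemo/inference/inference_utils.py` in this change already imports from `presets`. It is a minimal illustration only; the checkpoint path and the dummy frame array are placeholders, not part of this diff.

```python
import numpy as np

# Guardrail runner helpers, imported the same way as in inference_utils.py above.
from presets import (
    create_text_guardrail_runner,
    create_video_guardrail_runner,
    run_text_guardrail,
    run_video_guardrail,
)

guardrail_dir = "Cosmos-1.0-Guardrail"  # placeholder path to the downloaded guardrail weights

# Pre-Guard: blocklist + Aegis checks on the (optionally upsampled) text prompt.
text_runner = create_text_guardrail_runner(guardrail_dir)
if not run_text_guardrail("A video of sks teal robot cooking in a kitchen.", text_runner):
    raise ValueError("Guardrail blocked world generation.")

# Post-Guard: video content safety filter + face blur applied to generated frames [T, H, W, C].
video_runner = create_video_guardrail_runner(guardrail_dir)
frames = np.zeros((121, 704, 1280, 3), dtype=np.uint8)  # stand-in for decoded video frames
safe_frames = run_video_guardrail(frames, video_runner)
```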
diff --git a/cosmos1/models/guardrail/blocklist/__init__.py b/cosmos1/models/guardrail/blocklist/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos1/models/guardrail/blocklist/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos1/models/guardrail/common/__init__.py b/cosmos1/models/guardrail/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos1/models/guardrail/face_blur_filter/__init__.py b/cosmos1/models/guardrail/face_blur_filter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos1/models/guardrail/face_blur_filter/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos1/models/guardrail/video_content_safety_filter/__init__.py b/cosmos1/models/guardrail/video_content_safety_filter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos1/models/guardrail/video_content_safety_filter/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/cosmos1/scripts/format.sh b/cosmos1/scripts/format.sh new file mode 100644 index 0000000000000000000000000000000000000000..d65e804eabd86d32252ff50b63085f0a246804e0 --- /dev/null +++ b/cosmos1/scripts/format.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Script to automatically format code and files in cosmos. +# It will temporarily activate a virtualenv with predefined versions of formatters/linters. + +cosmos_root=$(git rev-parse --show-toplevel) +venv_folder=$cosmos_root/.venv +scripts_folder=$cosmos_root/cosmos1/scripts + +echo "Formatting $cosmos_root" +if [ ! -d "$scripts_folder" ]; then + echo "script has to be called from repo root dir!" + exit -1 +fi + +if [ ! -d "$venv_folder" ]; then + mkdir -p $venv_folder + python3 -m pip install virtualenv + python3 -m venv $venv_folder +fi + +source $venv_folder/bin/activate + +dependencies=($(pip freeze | grep -E 'pre-commit==3.7.1|flake8==7.1.0|black==24.4.2|isort==5.13.2|loguru|termcolor')) +if [ "${#dependencies[@]}" -ne 6 ]; then + python3 -m pip install --upgrade pip + python3 -m pip install pre-commit==3.7.1 + python3 -m pip install flake8==7.1.0 + python3 -m pip install black==24.4.2 + python3 -m pip install isort==5.13.2 + python3 -m pip install loguru + python3 -m pip install termcolor +fi +set -e +python3 $scripts_folder/ip_header.py +pre-commit install-hooks +pre-commit run --all diff --git a/cosmos1/scripts/ip_header.py b/cosmos1/scripts/ip_header.py new file mode 100644 index 0000000000000000000000000000000000000000..1f702a17f00f3601e15cc894f478ed10daaf19ab --- /dev/null +++ b/cosmos1/scripts/ip_header.py @@ -0,0 +1,178 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import sys + +import termcolor + +parser = argparse.ArgumentParser(description="Cosmos IP header checker/fixer") +parser.add_argument("--fix", action="store_true", help="apply the fixes instead of checking") +args, files_to_check = parser.parse_known_args() + + +def get_header(ext: str = "py", old: str | bool = False) -> list[str]: + # This is the raw header. + # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ # SPDX-License-Identifier: Apache-2.0 + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + header = [ + "SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.", + "SPDX-License-Identifier: Apache-2.0", + "", + 'Licensed under the Apache License, Version 2.0 (the "License");', + "you may not use this file except in compliance with the License.", + "You may obtain a copy of the License at", + "", + "http://www.apache.org/licenses/LICENSE-2.0", + "", + "Unless required by applicable law or agreed to in writing, software", + 'distributed under the License is distributed on an "AS IS" BASIS,', + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.", + "See the License for the specific language governing permissions and", + "limitations under the License.", + ] + # Reformat according to different file extensions. + if ext == ".py" and old: + if old == "single": + header = ["'''"] + header + ["'''"] + elif old == "double": + header = ['"""'] + header + ['"""'] + else: + raise NotImplementedError + elif ext in (".py", ".yaml"): + header = [("# " + line if line else "#") for line in header] + elif ext in (".c", ".cpp", ".cu", ".h", ".cuh"): + header = ["/*"] + [(" * " + line if line else " *") for line in header] + [" */"] + else: + raise NotImplementedError + return header + + +def apply_file(file: str, results: dict[str, int], fix: bool = False) -> None: + if file.endswith("__init__.py"): + return + ext = os.path.splitext(file)[1] + # Read the file content (line by line). + content = open(file).read().splitlines() + # Check if cosmos header (with a blank newline) is properly embedded. + header = get_header(ext=ext) + if fix: + # If header passes format check, then just exit + if _check_header(content, header): + return + print(f"fixing: {file}") + # Remove old header if exists. + if ext == ".py": + for header_old in [ + get_header(ext=ext, old="single"), + get_header(ext=ext, old="double"), + ]: + if content[: len(header_old)] == header_old: + content = content[len(header_old) :] + # Clean up leading blank lines. + while len(content) > 0 and not content[0]: + content.pop(0) + # Add cosmos copyright header. + content = header + [""] + content + # Write content back to file. + with open(file, "w") as file_obj: + for line in content: + file_obj.write(line + "\n") + else: + if not _check_header(content, header): + bad_header = colorize("BAD HEADER", color="red", bold=True) + print(f"{bad_header}: {file}") + results[file] = 1 + else: + results[file] = 0 + + +def traverse_directory(path: str, results: dict[str, int], fix: bool = False, substrings_to_skip=[]) -> None: + # Apply/check the header for an entire directory. + files = os.listdir(path) + for file in files: + full_path = os.path.join(path, file) + if os.path.isdir(full_path): + # Traverse into the subdirectory. + traverse_directory(full_path, results, fix=fix, substrings_to_skip=substrings_to_skip) + elif os.path.isfile(full_path): + # Process the file. 
+ ext = os.path.splitext(file)[1] + to_skip = False + for substr in substrings_to_skip: + if substr in full_path: + to_skip = True + break + + if not to_skip and ext in (".py", ".yaml", ".c", ".cpp", ".cu", ".h", ".cuh"): + apply_file(full_path, results, fix=fix) + else: + raise NotImplementedError + + +def _check_header(content: list[str], header: list[str]) -> bool: + if content[: len(header)] != header: + return False + if len(content) > len(header): + if len(content) == len(header) + 1: + return False + if not (content[len(header)] == "" and content[len(header) + 1] != ""): + return False + return True + + +def colorize(x: str, color: str, bold: bool = False) -> str: + return termcolor.colored(str(x), color=color, attrs=("bold",) if bold else None) # type: ignore + + +if __name__ == "__main__": + if not files_to_check: + # Default to the entire Cosmos repo. + files_to_check = [ + "cosmos1/utils", + "cosmos1/models", + "cosmos1/scripts", + ] + + # Check whether all input files/directories are valid. + for file in files_to_check: + assert os.path.isfile(file) or os.path.isdir(file), f"{file} is neither a directory or a file!" + + substrings_to_skip = ["prompt_upsampler"] + # Run the program. + results = dict() + for file in files_to_check: + if os.path.isfile(file): + apply_file(file, results, fix=args.fix) + elif os.path.isdir(file): + traverse_directory(file, results, fix=args.fix, substrings_to_skip=["prompt_upsampler"]) + else: + raise NotImplementedError + + if any(results.values()): + sys.exit(1) diff --git a/cosmos1diffusiontext2world.py b/cosmos1diffusiontext2world.py new file mode 100644 index 0000000000000000000000000000000000000000..ecd066a54e6e302aae82794b13e7cc271a331700 --- /dev/null +++ b/cosmos1diffusiontext2world.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
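The `ip_header.py` script above enforces that every `.py`, `.yaml`, and C/CUDA source file begins with the commented Apache header followed by exactly one blank line. Run without arguments from the repo root it checks `cosmos1/utils`, `cosmos1/models`, and `cosmos1/scripts` and exits non-zero on violations; with `--fix` it rewrites offending files in place (this is what `format.sh` invokes). The snippet below is a deliberately simplified, self-contained re-implementation of the check for illustration only, matching just the first SPDX line rather than the full header.

```python
# Simplified header check for illustration only; the repository's canonical logic
# lives in cosmos1/scripts/ip_header.py (get_header / _check_header / apply_file).
APACHE_HEADER_FIRST_LINE = (
    "# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved."
)


def file_has_header(path: str, header_first_line: str = APACHE_HEADER_FIRST_LINE) -> bool:
    """Loose check: the file starts with the SPDX copyright line of the header."""
    with open(path) as f:
        first_line = f.readline().rstrip("\n")
    return first_line == header_first_line


if __name__ == "__main__":
    import sys

    bad = [p for p in sys.argv[1:] if not file_has_header(p)]
    for p in bad:
        print(f"BAD HEADER: {p}")
    sys.exit(1 if bad else 0)
```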
+ +from hydra.core.config_store import ConfigStore + +from .lazy_config_init import LazyDict + +Cosmos_1_0_Diffusion_Text2World_7B: LazyDict = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, + "_self_", + ], + job=dict( + group="Text2World", + name="Cosmos_1_0_Diffusion_Text2World_7B", + ), + model=dict( + latent_shape=[ + 16, + 16, + 88, + 160, + ], + net=dict( + extra_per_block_abs_pos_emb=True, + rope_h_extrapolation_ratio=1.0, + rope_w_extrapolation_ratio=1.0, + rope_t_extrapolation_ratio=2.0, + extra_per_block_abs_pos_emb_type="learnable", + ), + ), + ) +) + + +Cosmos_1_0_Diffusion_Text2World_14B: LazyDict = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_14b"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, + "_self_", + ], + job=dict( + group="Text2World", + name="Cosmos_1_0_Diffusion_Text2World_14B", + ), + model=dict( + latent_shape=[ + 16, + 16, + 88, + 160, + ], + net=dict( + extra_per_block_abs_pos_emb=True, + rope_h_extrapolation_ratio=2.0, + rope_t_extrapolation_ratio=2.0, + rope_w_extrapolation_ratio=2.0, + extra_h_extrapolation_ratio=2.0, + extra_t_extrapolation_ratio=2.0, + extra_w_extrapolation_ratio=2.0, + extra_per_block_abs_pos_emb_type="learnable", + ), + ), + ) +) + +cs = ConfigStore.instance() +cs.store( + group="experiment", + package="_global_", + name=Cosmos_1_0_Diffusion_Text2World_7B["job"]["name"], + node=Cosmos_1_0_Diffusion_Text2World_7B, +) + +cs = ConfigStore.instance() +cs.store( + group="experiment", + package="_global_", + name=Cosmos_1_0_Diffusion_Text2World_14B["job"]["name"], + node=Cosmos_1_0_Diffusion_Text2World_14B, +) diff --git a/cosmos1diffusionvideo2world.py b/cosmos1diffusionvideo2world.py new file mode 100644 index 0000000000000000000000000000000000000000..641a47c0c875b44032ed40afd9d03617c4f19aec --- /dev/null +++ b/cosmos1diffusionvideo2world.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
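For the Text2World experiments above, `latent_shape = [16, 16, 88, 160]` lines up with the tokenizer named in the defaults list (`res720_comp8x8x8_t121`): 16 latent channels, 121 input frames compressed 8x in time with the first frame kept (the `1 + (T - 1) // 8` relation used by the tokenizer wrappers later in this diff), and 1280x704 pixels compressed 8x spatially, the resolution the data preprocessing script earlier in this diff resizes to. A quick sanity check of that arithmetic:

```python
# Sanity check of latent_shape = [16, 16, 88, 160] against the tokenizer named above
# (comp8x8x8, t121). Assumes 1280x704 pixel frames, as produced by the preprocessing
# script earlier in this diff.
latent_channels = 16
pixel_frames, height, width = 121, 704, 1280
temporal_compression, spatial_compression = 8, 8

latent_t = 1 + (pixel_frames - 1) // temporal_compression  # causal tokenizer keeps the first frame
latent_h = height // spatial_compression
latent_w = width // spatial_compression

assert [latent_channels, latent_t, latent_h, latent_w] == [16, 16, 88, 160]
```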
+ +from hydra.core.config_store import ConfigStore + +from .general_dit_video_conditioned import VideoExtendGeneralDIT +from .lazy_config_init import LazyCall as L +from .lazy_config_init import LazyDict + +Cosmos_1_0_Diffusion_Video2World_7B: LazyDict = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, + "_self_", + ], + model=dict( + latent_shape=[ + 16, + 16, + 88, + 160, + ], + conditioner=dict(video_cond_bool=dict()), + net=L(VideoExtendGeneralDIT)( + extra_per_block_abs_pos_emb=True, + rope_h_extrapolation_ratio=1.0, + rope_w_extrapolation_ratio=1.0, + rope_t_extrapolation_ratio=2.0, + extra_per_block_abs_pos_emb_type="learnable", + ), + ), + job=dict(group="Video2World", name="Cosmos_1_0_Diffusion_Video2World_7B"), + ) +) + + +Cosmos_1_0_Diffusion_Video2World_14B: LazyDict = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_14b"}, + {"override /conditioner": "video_cond"}, + {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, + "_self_", + ], + model=dict( + latent_shape=[ + 16, + 16, + 88, + 160, + ], + conditioner=dict(video_cond_bool=dict()), + net=L(VideoExtendGeneralDIT)( + extra_per_block_abs_pos_emb=True, + rope_h_extrapolation_ratio=2.0, + rope_t_extrapolation_ratio=2.0, + rope_w_extrapolation_ratio=2.0, + extra_h_extrapolation_ratio=2.0, + extra_t_extrapolation_ratio=2.0, + extra_w_extrapolation_ratio=2.0, + extra_per_block_abs_pos_emb_type="learnable", + ), + ), + job=dict(group="Video2World", name="Cosmos_1_0_Diffusion_Video2World_14B"), + ) +) + +cs = ConfigStore.instance() +cs.store( + group="experiment", + package="_global_", + name=Cosmos_1_0_Diffusion_Video2World_7B["job"]["name"], + node=Cosmos_1_0_Diffusion_Video2World_7B, +) + +cs = ConfigStore.instance() +cs.store( + group="experiment", + package="_global_", + name=Cosmos_1_0_Diffusion_Video2World_14B["job"]["name"], + node=Cosmos_1_0_Diffusion_Video2World_14B, +) diff --git a/denoiser_scaling.py b/denoiser_scaling.py new file mode 100644 index 0000000000000000000000000000000000000000..f4fb3df0f38d52de317177c248e22707a899beb4 --- /dev/null +++ b/denoiser_scaling.py @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
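The Video2World experiments differ from Text2World in two ways: they use the `video_cond` conditioner, and they declare the network inline as `L(VideoExtendGeneralDIT)(...)` rather than overriding a registered `net` group. `L` (LazyCall) records the target callable and its keyword arguments so the object can be constructed later by the config system. The stand-in below illustrates only that deferred-construction pattern; it is not the repo's actual `LazyCall`/`LazyDict` implementation from `lazy_config_init`.

```python
# Simplified stand-in for the lazy-call pattern: capture the target and kwargs now,
# construct the object later. Not the repo's actual LazyCall implementation.
from typing import Any, Callable


class LazyCallSketch:
    def __init__(self, target: Callable):
        self.target = target

    def __call__(self, **kwargs: Any) -> dict:
        return {"_target_": self.target, **kwargs}


def instantiate_sketch(cfg: dict) -> Any:
    kwargs = {k: v for k, v in cfg.items() if k != "_target_"}
    return cfg["_target_"](**kwargs)


class TinyNet:  # hypothetical module standing in for VideoExtendGeneralDIT
    def __init__(self, channels: int):
        self.channels = channels


lazy_net = LazyCallSketch(TinyNet)(channels=16)  # described now ...
net = instantiate_sketch(lazy_net)               # ... built later
assert net.channels == 16
```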
+ +from typing import Tuple + +import torch + + +class EDMScaling: + def __init__(self, sigma_data: float = 0.5): + self.sigma_data = sigma_data + + def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 + c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 + c_noise = 0.25 * sigma.log() + return c_skip, c_out, c_in, c_noise diff --git a/device.py b/device.py new file mode 100644 index 0000000000000000000000000000000000000000..db486afabd4ae0bf11feb05d8a4efd96690ce64b --- /dev/null +++ b/device.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os + +import pynvml + + +class Device: + """A class to handle NVIDIA GPU device operations using NVML. + + This class provides an interface to access and manage NVIDIA GPU devices, + including retrieving device information and CPU affinity settings. + + Attributes: + _nvml_affinity_elements (int): Number of 64-bit elements needed to represent CPU affinity + """ + + _nvml_affinity_elements = math.ceil(os.cpu_count() / 64) # type: ignore + + def __init__(self, device_idx: int): + """Initialize a Device instance for a specific GPU. + + Args: + device_idx (int): Index of the GPU device to manage + + Raises: + NVMLError: If the device cannot be found or initialized + """ + super().__init__() + self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx) + + def get_cpu_affinity(self) -> list[int]: + """Get the CPU affinity mask for this GPU device. + + Retrieves the CPU affinity mask indicating which CPU cores are assigned + to this GPU device. The affinity is returned as a list of CPU core indices. + + Returns: + list[int]: List of CPU core indices that have affinity with this GPU + + Raises: + NVMLError: If the CPU affinity information cannot be retrieved + + Example: + >>> device = Device(0) + >>> device.get_cpu_affinity() + [0, 1, 2, 3] # Shows this GPU has affinity with CPU cores 0-3 + """ + affinity_string = "" + for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements): + # assume nvml returns list of 64 bit ints + affinity_string = "{:064b}".format(j) + affinity_string + affinity_list = [int(x) for x in affinity_string] + affinity_list.reverse() # so core 0 is in 0th element of list + return [i for i, e in enumerate(affinity_list) if e != 0] diff --git a/df_base_model.py b/df_base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7b41f7fb8cd4cf73d89b3f6d550dfc2d19fbe254 --- /dev/null +++ b/df_base_model.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import attrs + +from .lazy_config_init import LazyDict + + +@attrs.define(slots=False) +class DefaultModelConfig: + tokenizer: LazyDict = None + conditioner: LazyDict = None + net: LazyDict = None + sigma_data: float = 0.5 + precision: str = "bfloat16" + input_data_key: str = "video" # key to fetch input data from data_batch + latent_shape: List[int] = [16, 24, 44, 80] # 24 corresponig to 136 frames + + +@attrs.define(slots=False) +class LatentDiffusionDecoderModelConfig(DefaultModelConfig): + tokenizer_corruptor: LazyDict = None + latent_corruptor: LazyDict = None + pixel_corruptor: LazyDict = None + diffusion_decoder_cond_sigma_low: float = None + diffusion_decoder_cond_sigma_high: float = None + diffusion_decoder_corrupt_prob: float = None + condition_on_tokenizer_corruptor_token: bool = False diff --git a/df_config_base_net.py b/df_config_base_net.py new file mode 100644 index 0000000000000000000000000000000000000000..99141c1708e65db996af873efc9f425a70e2e498 --- /dev/null +++ b/df_config_base_net.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from .general_dit import GeneralDIT +from .lazy_config_init import LazyCall as L +from .lazy_config_init import LazyDict + +FADITV2Config: LazyDict = L(GeneralDIT)( + max_img_h=240, + max_img_w=240, + max_frames=128, + in_channels=16, + out_channels=16, + patch_spatial=2, + patch_temporal=1, + model_channels=4096, + block_config="FA-CA-MLP", + num_blocks=28, + num_heads=32, + concat_padding_mask=True, + pos_emb_cls="rope3d", + pos_emb_learnable=False, + pos_emb_interpolation="crop", + block_x_format="THWBD", + affline_emb_norm=True, + use_adaln_lora=True, + adaln_lora_dim=256, +) + + +FADITV2_14B_Config = copy.deepcopy(FADITV2Config) +FADITV2_14B_Config.model_channels = 5120 +FADITV2_14B_Config.num_heads = 40 +FADITV2_14B_Config.num_blocks = 36 diff --git a/df_config_base_tokenizer.py b/df_config_base_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ecaf7d74e3eb6d31c2d91b4d61ab2f98c236480c --- /dev/null +++ b/df_config_base_tokenizer.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import omegaconf + +from .pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer +from .lazy_config_init import LazyCall as L + +TOKENIZER_OPTIONS = {} + + +def tokenizer_register(key): + def decorator(func): + TOKENIZER_OPTIONS[key] = func + return func + + return decorator + + +@tokenizer_register("cosmos_diffusion_tokenizer_comp8x8x8") +def get_cosmos_diffusion_tokenizer_comp8x8x8(resolution: str, chunk_duration: int) -> omegaconf.dictconfig.DictConfig: + assert resolution in ["720"] + + pixel_chunk_duration = chunk_duration + temporal_compression_factor = 8 + spatial_compression_factor = 8 + + return L(JointImageVideoSharedJITTokenizer)( + video_vae=L(VideoJITTokenizer)( + name="cosmos_1_0_diffusion_tokenizer", + latent_ch=16, + is_bf16=True, + pixel_chunk_duration=pixel_chunk_duration, + temporal_compression_factor=temporal_compression_factor, + spatial_compression_factor=spatial_compression_factor, + spatial_resolution=resolution, + ), + image_vae=L(JITVAE)( + name="cosmos_1_0_diffusion_tokenizer", + latent_ch=16, + is_image=False, + is_bf16=True, + ), + name="cosmos_1_0_diffusion_tokenizer", + latent_ch=16, + ) diff --git a/df_config_config.py b/df_config_config.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3c7e6851dbdd91606f6c1b91aa0ab1d2e9e2c9 --- /dev/null +++ b/df_config_config.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
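`tokenizer_register` populates `TOKENIZER_OPTIONS` so tokenizer config builders can be looked up by key. A hedged usage sketch follows; the import path is hypothetical and assumes the module's relative imports (`pretrained_vae`, `lazy_config_init`) resolve, i.e. that it is imported as part of the installed package rather than as a standalone file.

```python
# Usage sketch only: look up the registered builder and materialize the lazy tokenizer
# config for 720p video in 121-frame chunks (the combination this diff registers with
# Hydra in df_config_registry.py). The import path below is a hypothetical assumption.
from df_config_base_tokenizer import TOKENIZER_OPTIONS

builder = TOKENIZER_OPTIONS["cosmos_diffusion_tokenizer_comp8x8x8"]
tokenizer_cfg = builder(resolution="720", chunk_duration=121)
# tokenizer_cfg is a LazyCall config describing JointImageVideoSharedJITTokenizer;
# it is instantiated later by the config system, not here.
```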
+ +from typing import Any, List + +import attrs + +from .df_base_model import DefaultModelConfig +from .df_config_registry import register_configs +from .config import Config as ori_Config +from .config_helper import import_all_modules_from_package + +# import config here, not use importlib +from .cosmos1diffusiontext2world import LazyDict +from .cosmos1diffusionvideo2world import LazyDict + +@attrs.define(slots=False) +class Config(ori_Config): + # default config groups that will be used unless overwritten + # see config groups in registry.py + defaults: List[Any] = attrs.field( + factory=lambda: [ + "_self_", + {"net": None}, + {"conditioner": "add_fps_image_size_padding_mask"}, + {"tokenizer": "tokenizer"}, + {"experiment": None}, + ] + ) + + +def make_config(): + c = Config( + model=DefaultModelConfig(), + ) + + # Specifying values through instances of attrs + c.job.project = "cosmos_diffusion" + c.job.group = "inference" + + # Call this function to register config groups for advanced overriding. + register_configs() + + # experiment config are defined in the experiment folder + # call import_all_modules_from_package to register them + import_all_modules_from_package("cosmos1.models.diffusion.config.inference", reload=True) + return c diff --git a/df_config_registry.py b/df_config_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..308658b8793a690c9fdbd451e4221653374e28a4 --- /dev/null +++ b/df_config_registry.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from hydra.core.config_store import ConfigStore + +from .config_base_conditioner import ( + BaseVideoConditionerConfig, + VideoConditionerFpsSizePaddingConfig, + VideoExtendConditionerConfig, +) +from .df_config_base_net import FADITV2_14B_Config, FADITV2Config +from .df_config_base_tokenizer import get_cosmos_diffusion_tokenizer_comp8x8x8 + + +def register_net(cs): + cs.store( + group="net", + package="model.net", + name="faditv2_7b", + node=FADITV2Config, + ) + cs.store( + group="net", + package="model.net", + name="faditv2_14b", + node=FADITV2_14B_Config, + ) + + +def register_conditioner(cs): + cs.store( + group="conditioner", + package="model.conditioner", + name="basic", + node=BaseVideoConditionerConfig, + ) + cs.store( + group="conditioner", + package="model.conditioner", + name="add_fps_image_size_padding_mask", + node=VideoConditionerFpsSizePaddingConfig, + ) + cs.store( + group="conditioner", + package="model.conditioner", + name="video_cond", + node=VideoExtendConditionerConfig, + ) + + +def register_tokenizer(cs): + cs.store( + group="tokenizer", + package="model.tokenizer", + name="cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624", + node=get_cosmos_diffusion_tokenizer_comp8x8x8(resolution="720", chunk_duration=121), + ) + + +def register_configs(): + cs = ConfigStore.instance() + + register_net(cs) + register_conditioner(cs) + register_tokenizer(cs) diff --git a/diffusion_types.py b/diffusion_types.py new file mode 100644 index 0000000000000000000000000000000000000000..a209db0eba28a8d8bcb527bfbaca6f5e361ace14 --- /dev/null +++ b/diffusion_types.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass +class DenoisePrediction: + x0: torch.Tensor # clean data prediction + eps: Optional[torch.Tensor] = None # noise prediction + logvar: Optional[torch.Tensor] = None # log variance of noise prediction, can be used a confidence / uncertainty diff --git a/discrete_video.py b/discrete_video.py new file mode 100644 index 0000000000000000000000000000000000000000..5e5a5244c87516121f3e7686c924f8b1c66cd772 --- /dev/null +++ b/discrete_video.py @@ -0,0 +1,360 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from einops import rearrange + +from .ar_tokenizer_quantizers import FSQuantizer + +# Make sure jit model output consistenly during consecutive calls +# Check here: https://github.com/pytorch/pytorch/issues/74534 +torch._C._jit_set_texpr_fuser_enabled(False) + + +def load_jit_model(jit_filepath: str = None, device: str = "cuda") -> torch.jit.ScriptModule: + """Loads a torch.jit.ScriptModule from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + # Make sure jit model output consistenly during consecutive calls + # Check here: https://github.com/pytorch/pytorch/issues/74534 + torch._C._jit_set_texpr_fuser_enabled(False) + + model = torch.jit.load(jit_filepath) + return model.eval().to(device) + + +class BaseDiscreteVideoFSQTokenizer(torch.nn.Module): + """ + A base class for Discrete Video FSQ Tokenizer that handles data type conversions, and normalization + using provided mean and standard deviation values for latent space representation. + Derived classes should load pre-trained encoder and decoder components into a encoder and decoder attributes. + + Attributes: + encoder (Module | Callable): Encoder loaded from storage. + decoder (Module | Callable): Decoder loaded from storage. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + name (str): Name of the model, used for differentiating cache file paths. + latent_ch (int, optional): Number of latent channels (default is 6). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). + pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level. + latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level. + max_enc_batch_size (int): The maximum batch size to process in one go to avoid memory overflow. + level (list[int]): The level defined in FSQ quantizer. + compression_ratio (list[int]): The compression factor for (T, H, W). + """ + + def __init__( + self, + name: str, + latent_ch: int = 6, + is_bf16: bool = True, + pixel_chunk_duration: int = 25, + latent_chunk_duration: int = 4, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + levels: list[int] = [8, 8, 8, 5, 5, 5], + compression_ratio: list[int] = [8, 16, 16], + ): + super().__init__() + self.channel = latent_ch + self.name = name + dtype = torch.bfloat16 if is_bf16 else torch.float32 + self.dtype = dtype + self.pixel_chunk_duration = pixel_chunk_duration + self.latent_chunk_duration = latent_chunk_duration + self.max_enc_batch_size = max_enc_batch_size + self.max_dec_batch_size = max_dec_batch_size + self.levels = levels + self.compress_ratio = compression_ratio + self.fsq_quantizer = FSQuantizer(levels) + + @property + def latent_ch(self) -> int: + """ + Returns the number of latent channels in the tokenizer. 
+ """ + return self.channel + + @torch.no_grad() + def encode(self, state: torch.Tensor, pixel_chunk_duration: Optional[int] = None) -> torch.Tensor: + B, C, T, H, W = state.shape + if pixel_chunk_duration is None: + # Use the default pixel chunk duration and latent chunk duration + pixel_chunk_duration = self.pixel_chunk_duration + latent_chunk_duration = self.latent_chunk_duration + else: + # Update the latent chunk duration based on the given pixel chunk duration + latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // self.compress_ratio[0] + + assert ( + T % pixel_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {pixel_chunk_duration}" + state = rearrange(state, "b c (n t) h w -> (b n) c t h w", t=pixel_chunk_duration) + + # use max_enc_batch_size to avoid OOM + if state.shape[0] > self.max_enc_batch_size: + quantized_out_list = [] + indices_list = [] + for i in range(0, state.shape[0], self.max_enc_batch_size): + indices, quantized_out, _ = self.encoder(state[i : i + self.max_enc_batch_size].to(self.dtype)) + quantized_out_list.append(quantized_out) + indices_list.append(indices) + quantized_out = torch.cat(quantized_out_list, dim=0) + indices = torch.cat(indices_list, dim=0) + else: + indices, quantized_out, _ = self.encoder(state.to(self.dtype)) + assert quantized_out.shape[2] == latent_chunk_duration + return rearrange(quantized_out, "(b n) c t h w -> b c (n t) h w", b=B), rearrange( + indices, "(b n) t h w -> b (n t) h w", b=B + ) + + @torch.no_grad() + def decode(self, indices: torch.Tensor, pixel_chunk_duration: Optional[int] = None) -> torch.Tensor: + B, T, _, _ = indices.shape + if pixel_chunk_duration is None: + pixel_chunk_duration = self.pixel_chunk_duration + latent_chunk_duration = self.latent_chunk_duration + else: + latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // self.compress_ratio[0] + assert ( + T % latent_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {latent_chunk_duration}" + indices = rearrange(indices, "b (n t) h w -> (b n) t h w", t=latent_chunk_duration) + + # use max_dec_batch_size to avoid OOM + if indices.shape[0] > self.max_dec_batch_size: + state = [] + for i in range(0, indices.shape[0], self.max_dec_batch_size): + state.append(self.decoder(indices[i : i + self.max_dec_batch_size])) + state = torch.cat(state, dim=0) + else: + state = self.decoder(indices) + + assert state.shape[2] == pixel_chunk_duration + return rearrange(state, "(b n) c t h w -> b c (n t) h w", b=B) + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. + """ + del args, kwargs + self.decoder.to(self.dtype) + self.encoder.to(self.dtype) + + +class DiscreteVideoFSQJITTokenizer(BaseDiscreteVideoFSQTokenizer): + """ + A JIT compiled Discrete Video FSQ Tokenizer that loads pre-trained encoder + and decoder components from a remote store, handles data type conversions, and normalization + using provided mean and standard deviation values for latent space representation. + + Attributes: + encoder (Module): The JIT compiled encoder loaded from storage. + decoder (Module): The JIT compiled decoder loaded from storage. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + enc_fp (str): File path to the encoder's JIT file on the remote store. 
+ dec_fp (str): File path to the decoder's JIT file on the remote store. + name (str): Name of the model, used for differentiating cache file paths. + latent_ch (int, optional): Number of latent channels (default is 6). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). + pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level. + latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level. + max_enc_batch_size (int): The maximum batch size to process in one go to avoid memory overflow. + level (list[int]): The level defined in FSQ quantizer. + compression_ratio (list[int]): The compression factor for (T, H, W). + """ + + def __init__( + self, + enc_fp: str, + dec_fp: str, + name: str, + latent_ch: int = 6, + is_bf16: bool = True, + pixel_chunk_duration: int = 25, + latent_chunk_duration: int = 4, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + levels: list[int] = [8, 8, 8, 5, 5, 5], + compression_ratio: list[int] = [8, 16, 16], + ): + super().__init__( + name, + latent_ch, + is_bf16, + pixel_chunk_duration, + latent_chunk_duration, + max_enc_batch_size, + max_dec_batch_size, + levels, + compression_ratio, + ) + + self.load_encoder(enc_fp) + self.load_decoder(dec_fp) + + def load_encoder(self, enc_fp: str) -> None: + """ + Load the encoder from the remote store. + + Args: + - enc_fp (str): File path to the encoder's JIT file on the remote store. + """ + self.encoder = load_jit_model(enc_fp, device="cuda") + self.encoder.eval() + for param in self.encoder.parameters(): + param.requires_grad = False + self.encoder.to(self.dtype) + + def load_decoder(self, dec_fp: str) -> None: + """ + Load the decoder from the remote store. + + Args: + - dec_fp (str): File path to the decoder's JIT file on the remote store. + """ + self.decoder = load_jit_model(dec_fp, device="cuda") + self.decoder.eval() + for param in self.decoder.parameters(): + param.requires_grad = False + self.decoder.to(self.dtype) + + +class DiscreteVideoFSQStateDictTokenizer(BaseDiscreteVideoFSQTokenizer): + """ + A Discrete Video FSQ Tokenizer that loads weights from pre-trained JITed encoder + into as nn.Module so that encoder can be "torch.compile()" and JITed decoder, so it can be torch.compiled, + handles data type conversions, and normalization using provided mean and standard deviation values for latent + space representation. + + Attributes: + tokenizer_module (Module): Tokenizer module with weights loaded from JIT checkpoints + encoder (Callable): tokenizer_module's encode method + decoder (Callable): tokenizer_module's decode method + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + enc_fp (str): File path to the encoder's JIT file on the remote store. + dec_fp (str): File path to the decoder's JIT file on the remote store. + tokenizer_module (Module): Tokenizer module that will have it's weights loaded + name (str): Name of the model, used for differentiating cache file paths. + latent_ch (int, optional): Number of latent channels (default is 6). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). + pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level. + latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level. 
+ max_enc_batch_size (int): The maximum batch size to process in one go to avoid memory overflow. + level (list[int]): The level defined in FSQ quantizer. + compression_ratio (list[int]): The compression factor for (T, H, W). + """ + + def __init__( + self, + enc_fp: str, + dec_fp: str, + tokenizer_module: torch.nn.Module, + name: str, + latent_ch: int = 6, + is_bf16: bool = True, + pixel_chunk_duration: int = 25, + latent_chunk_duration: int = 4, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + levels: list[int] = [8, 8, 8, 5, 5, 5], + compression_ratio: list[int] = [8, 16, 16], + ): + super().__init__( + name, + latent_ch, + is_bf16, + pixel_chunk_duration, + latent_chunk_duration, + max_enc_batch_size, + max_dec_batch_size, + levels, + compression_ratio, + ) + + self.load_encoder_and_decoder(enc_fp, dec_fp, tokenizer_module) + + def load_encoder_and_decoder(self, enc_fp: str, dec_fp: str, tokenizer_module: torch.nn.Module) -> None: + """ + Load the encoder from the remote store. + + Args: + - enc_fp (str): File path to the encoder's JIT file on the remote store. + - def_fp (str): File path to the decoder's JIT file on the remote store. + - tokenizer_module (Module): Tokenizer module that was used to create JIT checkpoints + """ + self.decoder = load_jit_model(dec_fp) + + self.decoder.eval() + for param in self.decoder.parameters(): + param.requires_grad = False + self.decoder.to(self.dtype) + + encoder_sd = load_jit_model(enc_fp).state_dict() + + del tokenizer_module.post_quant_conv + del tokenizer_module.decoder + + state_dict = { + k: v + for k, v in (encoder_sd).items() + # Variables captured by JIT + if k + not in ( + "encoder.patcher3d.wavelets", + "encoder.patcher3d._arange", + "encoder.patcher3d.patch_size_buffer", + "quantizer._levels", + "quantizer._basis", + "quantizer.implicit_codebook", + ) + } + + tokenizer_module.load_state_dict(state_dict) + + tokenizer_module.eval() + for param in tokenizer_module.parameters(): + param.requires_grad = False + tokenizer_module.to(self.dtype) + + self.tokenizer_module = tokenizer_module + self.encoder = self.tokenizer_module.encode + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. + """ + del args, kwargs + self.decoder.to(self.dtype) + self.tokenizer_module.to(self.dtype) diff --git a/distributed.py b/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..69f477ced9dfe59deda742bc507addf7d7268bdf --- /dev/null +++ b/distributed.py @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
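Both discrete tokenizer wrappers above split the time axis into fixed-length chunks before encoding: the input length must be a multiple of `pixel_chunk_duration`, each chunk contributes `1 + (pixel_chunk_duration - 1) // 8` latent frames, and the spatial dimensions shrink by the `[8, 16, 16]` compression ratio. The helper below reproduces just that shape bookkeeping under the default constructor arguments; no JIT checkpoints are loaded.

```python
# Shape bookkeeping only; mirrors the chunking logic in
# BaseDiscreteVideoFSQTokenizer.encode with the defaults above
# (pixel_chunk_duration=25, compression_ratio=[8, 16, 16]).
def latent_video_shape(T: int, H: int, W: int,
                       pixel_chunk_duration: int = 25,
                       compression_ratio=(8, 16, 16)) -> tuple[int, int, int]:
    assert T % pixel_chunk_duration == 0, "temporal length must be a multiple of the chunk duration"
    n_chunks = T // pixel_chunk_duration
    latent_t_per_chunk = 1 + (pixel_chunk_duration - 1) // compression_ratio[0]
    return (n_chunks * latent_t_per_chunk,
            H // compression_ratio[1],
            W // compression_ratio[2])


# e.g. a 50-frame 512x512 clip -> two 25-frame chunks -> token indices of shape (T=8, H=32, W=32)
assert latent_video_shape(50, 512, 512) == (8, 32, 32)
```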
+ +from __future__ import annotations + +import collections +import collections.abc +import ctypes +import functools +import os +from datetime import timedelta +from typing import Any, Callable, Optional + +import pynvml +import torch +import torch.distributed as dist + +from .log import log +from .device import Device + + +def init() -> int | None: + """Initialize distributed training.""" + # Set GPU affinity. + pynvml.nvmlInit() + local_rank = int(os.getenv("LOCAL_RANK", 0)) + device = Device(local_rank) + os.sched_setaffinity(0, device.get_cpu_affinity()) + # Set up NCCL communication. + os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "0" + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1" + if dist.is_available(): + if dist.is_initialized(): + return torch.cuda.current_device() + torch.cuda.set_device(local_rank) + # Get the timeout value from environment variable + timeout_seconds = os.getenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", 1800) + # Convert the timeout to an integer (if it isn't already) and then to a timedelta + timeout_timedelta = timedelta(seconds=int(timeout_seconds)) + dist.init_process_group(backend="nccl", init_method="env://", timeout=timeout_timedelta) + log.critical( + f"Initialized distributed training with local rank {local_rank} with timeout {timeout_seconds}", + rank0_only=False, + ) + # Increase the L2 fetch granularity for faster speed. + _libcudart = ctypes.CDLL("libcudart.so") + # Set device limit on the current device. + p_value = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int)) + _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128)) + _libcudart.cudaDeviceGetLimit(p_value, ctypes.c_int(0x05)) + log.info(f"Training with {get_world_size()} GPUs.") + + +def get_rank(group: Optional[dist.ProcessGroup] = None) -> int: + """Get the rank (GPU device) of the worker. + + Returns: + rank (int): The rank of the worker. + """ + rank = 0 + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank(group) + return rank + + +def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int: + """Get world size. How many GPUs are available in this job. + + Returns: + world_size (int): The total number of GPUs available in this job. + """ + world_size = 1 + if dist.is_available() and dist.is_initialized(): + world_size = dist.get_world_size(group) + return world_size + + +def is_rank0() -> bool: + """Check if current process is the master GPU. + + Returns: + (bool): True if this function is called from the master GPU, else False. + """ + return get_rank() == 0 + + +def rank0_only(func: Callable) -> Callable: + """Apply this function only to the master GPU. + + Example usage: + @rank0_only + def func(x): + return x + 3 + + Args: + func (Callable): a function. + + Returns: + (Callable): A function wrapper executing the function only on the master GPU. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): # noqa: ANN202 + if is_rank0(): + return func(*args, **kwargs) + else: + return None + + return wrapper + + +def barrier() -> None: + """Barrier for all GPUs.""" + if dist.is_available() and dist.is_initialized(): + dist.barrier() + + +class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel): + """This extends torch.nn.parallel.DistributedDataParallel with .training_step(). + + This borrows the concept of `forward-redirection` from Pytorch lightning. 
It wraps an coreModel such that + model.training_step() would be executed when calling self.training_step(), while preserving the behavior of calling + model() for Pytorch modules. Internally, this is a double rerouting mechanism (training_step -> forward -> + training_step), allowing us to preserve the function names and signatures. + """ + + def __init__(self, model: torch.nn.Module, *args, **kwargs): + super().__init__(model, *args, **kwargs) + + def training_step(self, *args, **kwargs) -> Any: + # Cache the original model.forward() method. + original_forward = self.module.forward + + def wrapped_training_step(*_args, **_kwargs): # noqa: ANN202 + # Unpatch immediately before calling training_step() because itself may want to call the real forward. + self.module.forward = original_forward + # The actual .training_step(). + return self.module.training_step(*_args, **_kwargs) + + # Patch the original_module's forward so we can redirect the arguments back to the real method. + self.module.forward = wrapped_training_step + # Call self, which implicitly calls self.forward() --> model.forward(), which is now model.training_step(). + # Without calling self.forward() or model.forward() explciitly, implicit hooks are also executed. + return self(*args, **kwargs) + + +def collate_batches(data_batches: list[dict[str, torch.Tensor]]) -> torch.Tensor | dict[str, torch.Tensor]: + """Aggregate the list of data batches from all devices and process the results. + + This is used for gathering validation data batches with utils.dataloader.DistributedEvalSampler. + It will return the data/output of the entire validation set in its original index order. The sizes of data_batches + in different ranks may differ by 1 (if dataset size is not evenly divisible), in which case a dummy sample will be + created before calling dis.all_gather(). + + Args: + data_batches (list[dict[str, torch.Tensor]]): List of tensors or (hierarchical) dictionary where + leaf entries are tensors. + + Returns: + data_gather (torch.Tensor | dict[str, torch.Tensor]): tensors or (hierarchical) dictionary where + leaf entries are concatenated tensors. + """ + if isinstance(data_batches[0], torch.Tensor): + # Concatenate the local data batches. + data_concat = torch.cat(data_batches, dim=0) # type: ignore + # Get the largest number of local samples from all ranks to determine whether to dummy-pad on this rank. + max_num_local_samples = torch.tensor(len(data_concat), device="cuda") + dist.all_reduce(max_num_local_samples, op=dist.ReduceOp.MAX) + if len(data_concat) < max_num_local_samples: + assert len(data_concat) + 1 == max_num_local_samples + dummy = torch.empty_like(data_concat[:1]) + data_concat = torch.cat([data_concat, dummy], dim=0) + dummy_count = torch.tensor(1, device="cuda") + else: + dummy_count = torch.tensor(0, device="cuda") + # Get all concatenated batches from all ranks and concatenate again. + dist.all_reduce(dummy_count, op=dist.ReduceOp.SUM) + data_concat = all_gather_tensor(data_concat.contiguous()) + data_collate = torch.stack(data_concat, dim=1).flatten(start_dim=0, end_dim=1) + # Remove the dummy samples. 
+ if dummy_count > 0: + data_collate = data_collate[:-dummy_count] + elif isinstance(data_batches[0], collections.abc.Mapping): + data_collate = dict() + for key in data_batches[0].keys(): + data_collate[key] = collate_batches([data[key] for data in data_batches]) # type: ignore + else: + raise TypeError + return data_collate + + +@torch.no_grad() +def all_gather_tensor(tensor: torch.Tensor) -> list[torch.Tensor]: + """Gather the corresponding tensor from all GPU devices to a list. + + Args: + tensor (torch.Tensor): Pytorch tensor. + + Returns: + tensor_list (list[torch.Tensor]): A list of Pytorch tensors gathered from all GPU devices. + """ + tensor_list = [torch.zeros_like(tensor) for _ in range(get_world_size())] + dist.all_gather(tensor_list, tensor) + return tensor_list + + +def broadcast(tensor, src, group=None, async_op=False): + world_size = get_world_size() + if world_size < 2: + return tensor + dist.broadcast(tensor, src=src, group=group, async_op=async_op) diff --git a/download_autoregressive.py b/download_autoregressive.py new file mode 100644 index 0000000000000000000000000000000000000000..b08ba7bcabddf9f9e636a1ec65a56bc53e160c7d --- /dev/null +++ b/download_autoregressive.py @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from pathlib import Path + +from huggingface_hub import snapshot_download + + +def parse_args(): + parser = argparse.ArgumentParser(description="Download NVIDIA Cosmos-1.0 Autoregressive models from Hugging Face") + parser.add_argument( + "--model_sizes", + nargs="*", + default=[ + "4B", + "5B", + ], # Download all by default + choices=["4B", "5B", "12B", "13B"], + help="Which model sizes to download. Possible values: 4B, 5B, 12B, 13B.", + ) + parser.add_argument( + "--cosmos_version", + type=str, + default="1.0", + choices=["1.0"], + help="Which version of Cosmos to download. Only 1.0 is available at the moment.", + ) + parser.add_argument( + "--checkpoint_dir", type=str, default="checkpoints", help="Directory to save the downloaded checkpoints." 
+ ) + args = parser.parse_args() + return args + + +def main(args): + ORG_NAME = "nvidia" + + # Mapping from size argument to Hugging Face repository name + model_map = { + "4B": "Cosmos-1.0-Autoregressive-4B", + "5B": "Cosmos-1.0-Autoregressive-5B-Video2World", + "12B": "Cosmos-1.0-Autoregressive-12B", + "13B": "Cosmos-1.0-Autoregressive-13B-Video2World", + } + + # Additional models that are always downloaded + extra_models = [ + "Cosmos-1.0-Guardrail", + "Cosmos-1.0-Diffusion-7B-Decoder-DV8x16x16ToCV8x8x8", + "Cosmos-1.0-Tokenizer-CV8x8x8", + "Cosmos-1.0-Tokenizer-DV8x16x16", + ] + + # Create local checkpoints folder + checkpoints_dir = Path(args.checkpoint_dir) + checkpoints_dir.mkdir(parents=True, exist_ok=True) + + download_kwargs = dict(allow_patterns=["README.md", "model.pt", "config.json", "*.jit"]) + + # Download the requested Autoregressive models + for size in args.model_sizes: + model_name = model_map[size] + repo_id = f"{ORG_NAME}/{model_name}" + local_dir = checkpoints_dir.joinpath(model_name) + local_dir.mkdir(parents=True, exist_ok=True) + + print(f"Downloading {repo_id} to {local_dir}...") + snapshot_download( + repo_id=repo_id, + local_dir=str(local_dir), + local_dir_use_symlinks=False, + **download_kwargs, + ) + + # Download the always-included models + for model_name in extra_models: + repo_id = f"{ORG_NAME}/{model_name}" + local_dir = checkpoints_dir.joinpath(model_name) + local_dir.mkdir(parents=True, exist_ok=True) + + print(f"Downloading {repo_id} to {local_dir}...") + # Download all files + snapshot_download( + repo_id=repo_id, + local_dir=str(local_dir), + local_dir_use_symlinks=False, + ) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/download_diffusion.py b/download_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..8ddbd7412619f0ca3f922900a99e9c10f693117e --- /dev/null +++ b/download_diffusion.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
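`download_autoregressive.py` is a thin wrapper around `huggingface_hub.snapshot_download`: for each requested size it pulls the matching `nvidia/Cosmos-1.0-Autoregressive-*` repository restricted to `README.md`, `model.pt`, `config.json`, and `*.jit`, then unconditionally downloads the guardrail, diffusion-decoder, and tokenizer repositories in full. Equivalent direct use for a single model is sketched below; access to the `nvidia` repositories may require prior authentication (for example `huggingface-cli login`), which the script itself does not handle.

```python
# Minimal sketch: fetch just the 4B autoregressive model and the guardrail repo,
# using the same allow_patterns as the script above. Access to the nvidia repos
# may require prior authentication.
from pathlib import Path

from huggingface_hub import snapshot_download

checkpoints_dir = Path("checkpoints")
checkpoints_dir.mkdir(parents=True, exist_ok=True)

snapshot_download(
    repo_id="nvidia/Cosmos-1.0-Autoregressive-4B",
    local_dir=str(checkpoints_dir / "Cosmos-1.0-Autoregressive-4B"),
    allow_patterns=["README.md", "model.pt", "config.json", "*.jit"],
)
snapshot_download(
    repo_id="nvidia/Cosmos-1.0-Guardrail",
    local_dir=str(checkpoints_dir / "Cosmos-1.0-Guardrail"),
)
```

The same result follows from the CLI form `python download_autoregressive.py --model_sizes 4B --checkpoint_dir checkpoints`, which additionally fetches the diffusion decoder and both tokenizers.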
+ +import argparse +from pathlib import Path + +from huggingface_hub import snapshot_download + +from .convert_pixtral_ckpt import convert_pixtral_checkpoint + + +def main(model_types, model_sizes, checkpoint_dir="checkpoints"): + ORG_NAME = "nvidia" + + # Mapping from size argument to Hugging Face repository name + model_map = { + "7B": "Cosmos-1.0-Diffusion-7B", + "14B": "Cosmos-1.0-Diffusion-14B", + } + + # Additional models that are always downloaded + extra_models = [ + "Cosmos-1.0-Guardrail", + "Cosmos-1.0-Tokenizer-CV8x8x8", + ] + + if "Text2World" in model_types: + extra_models.append("Cosmos-1.0-Prompt-Upsampler-12B-Text2World") + + # Create local checkpoints folder + checkpoints_dir = Path(checkpoint_dir) + checkpoints_dir.mkdir(parents=True, exist_ok=True) + + download_kwargs = dict(allow_patterns=["README.md", "model.pt", "config.json", "*.jit"]) + + # Download the requested Autoregressive models + for size in model_sizes: + for model_type in model_types: + suffix = f"-{model_type}" + model_name = model_map[size] + suffix + repo_id = f"{ORG_NAME}/{model_name}" + local_dir = checkpoints_dir.joinpath(model_name) + local_dir.mkdir(parents=True, exist_ok=True) + + print(f"Downloading {repo_id} to {local_dir}...") + snapshot_download( + repo_id=repo_id, local_dir=str(local_dir), local_dir_use_symlinks=False, **download_kwargs + ) + + # Download the always-included models + for model_name in extra_models: + repo_id = f"{ORG_NAME}/{model_name}" + local_dir = checkpoints_dir.joinpath(model_name) + local_dir.mkdir(parents=True, exist_ok=True) + + print(f"Downloading {repo_id} to {local_dir}...") + # Download all files for Guardrail + snapshot_download( + repo_id=repo_id, + local_dir=str(local_dir), + local_dir_use_symlinks=False, + ) + + if "Video2World" in model_types: + # Prompt Upsampler for Cosmos-1.0-Diffusion-Video2World models + convert_pixtral_checkpoint( + checkpoint_dir=checkpoint_dir, + checkpoint_name="Pixtral-12B", + vit_type="pixtral-12b-vit", + ) + diff --git a/face_blur_filter.py b/face_blur_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..89af1fbe4b2c6e6d2c1463bad1498918021ac82c --- /dev/null +++ b/face_blur_filter.py @@ -0,0 +1,224 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
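The repository names requested by `download_diffusion.py` are built implicitly inside the nested loops below (the size map plus a `-{model_type}` suffix). A small illustration of the resulting repo ids, using only values that appear in the script:

```python
# Illustration only: repo ids produced by combining the size map with the
# requested model types, mirroring the nested loop in download_diffusion.py.
model_map = {"7B": "Cosmos-1.0-Diffusion-7B", "14B": "Cosmos-1.0-Diffusion-14B"}

for size in ["7B", "14B"]:
    for model_type in ["Text2World", "Video2World"]:
        print(f"nvidia/{model_map[size]}-{model_type}")

# nvidia/Cosmos-1.0-Diffusion-7B-Text2World
# nvidia/Cosmos-1.0-Diffusion-7B-Video2World
# nvidia/Cosmos-1.0-Diffusion-14B-Text2World
# nvidia/Cosmos-1.0-Diffusion-14B-Video2World
```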
+ +import argparse +import os + +from .log import log +import numpy as np +import torch +from pytorch_retinaface.data import cfg_re50 +from pytorch_retinaface.layers.functions.prior_box import PriorBox +from pytorch_retinaface.models.retinaface import RetinaFace +from torch.utils.data import DataLoader, TensorDataset +from tqdm import tqdm + +from .guardrail_core import GuardrailRunner, PostprocessingGuardrail +from .guardrail_io_utils import get_video_filepaths, read_video, save_video +from .blur_utils import pixelate_face +from .retinaface_utils import decode_batch, filter_detected_boxes, load_model +from .misc import misc, Color, timer + +DEFAULT_RETINAFACE_CHECKPOINT = "checkpoints/Cosmos-1.0-Guardrail/face_blur_filter/Resnet50_Final.pth" + +# RetinaFace model constants from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py +TOP_K = 5_000 +KEEP_TOP_K = 750 +NMS_THRESHOLD = 0.4 + + +class RetinaFaceFilter(PostprocessingGuardrail): + def __init__( + self, + checkpoint: str = DEFAULT_RETINAFACE_CHECKPOINT, + batch_size: int = 1, + confidence_threshold: float = 0.7, + device="cuda" if torch.cuda.is_available() else "cpu", + ) -> None: + """ + Initialize the RetinaFace model for face detection and blurring. + + Args: + checkpoint: Path to the RetinaFace checkpoint file + batch_size: Batch size for RetinaFace inference and processing + confidence_threshold: Minimum confidence score to consider a face detection + """ + self.cfg = cfg_re50 + self.batch_size = batch_size + self.confidence_threshold = confidence_threshold + self.device = device + self.dtype = torch.float32 + + # Disable loading ResNet pretrained weights + self.cfg["pretrain"] = False + self.net = RetinaFace(cfg=self.cfg, phase="test") + cpu = self.device == "cpu" + + # Load from RetinaFace pretrained checkpoint + self.net = load_model(self.net, checkpoint, cpu) + self.net.to(self.device, dtype=self.dtype).eval() + + def preprocess_frames(self, frames: np.ndarray) -> torch.Tensor: + """Preprocess a sequence of frames for face detection. + + Args: + frames: Input frames + + Returns: + Preprocessed frames tensor + """ + with torch.no_grad(): + frames_tensor = torch.from_numpy(frames).to(self.device, dtype=self.dtype) # Shape: [T, H, W, C] + frames_tensor = frames_tensor.permute(0, 3, 1, 2) # Shape: [T, C, H, W] + frames_tensor = frames_tensor[:, [2, 1, 0], :, :] # RGB to BGR to match RetinaFace model input + means = torch.tensor([104.0, 117.0, 123.0], device=self.device, dtype=self.dtype).view(1, 3, 1, 1) + frames_tensor = frames_tensor - means # Subtract mean BGR values for each channel + return frames_tensor + + def blur_detected_faces( + self, + frames: np.ndarray, + batch_loc: torch.Tensor, + batch_conf: torch.Tensor, + prior_data: torch.Tensor, + scale: torch.Tensor, + min_size: tuple[int] = (20, 20), + ) -> list[np.ndarray]: + """Blur detected faces in a batch of frames using RetinaFace predictions. 
+ + Args: + frames: Input frames + batch_loc: Batched location predictions + batch_conf: Batched confidence scores + prior_data: Prior boxes for the video + scale: Scale factor for resizing detections + min_size: Minimum size of a detected face region in pixels + + Returns: + Processed frames with pixelated faces + """ + with torch.no_grad(): + batch_boxes = decode_batch(batch_loc, prior_data, self.cfg["variance"]) + batch_boxes = batch_boxes * scale + + blurred_frames = [] + for i, boxes in enumerate(batch_boxes): + boxes = boxes.detach().cpu().numpy() + scores = batch_conf[i, :, 1].detach().cpu().numpy() + + filtered_boxes = filter_detected_boxes( + boxes, + scores, + confidence_threshold=self.confidence_threshold, + nms_threshold=NMS_THRESHOLD, + top_k=TOP_K, + keep_top_k=KEEP_TOP_K, + ) + + frame = frames[i] + for box in filtered_boxes: + x1, y1, x2, y2 = map(int, box) + # Ignore bounding boxes smaller than the minimum size + if x2 - x1 < min_size[0] or y2 - y1 < min_size[1]: + continue + max_h, max_w = frame.shape[:2] + face_roi = frame[max(y1, 0) : min(y2, max_h), max(x1, 0) : min(x2, max_w)] + blurred_face = pixelate_face(face_roi) + frame[max(y1, 0) : min(y2, max_h), max(x1, 0) : min(x2, max_w)] = blurred_face + blurred_frames.append(frame) + + return blurred_frames + + def postprocess(self, frames: np.ndarray) -> np.ndarray: + """Blur faces in a sequence of frames. + + Args: + frames: Input frames + + Returns: + Processed frames with pixelated faces + """ + # Create dataset and dataloader + frames_tensor = self.preprocess_frames(frames) + dataset = TensorDataset(frames_tensor) + dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False) + processed_frames, processed_batches = [], [] + + prior_data, scale = None, None + for i, batch in enumerate(dataloader): + batch = batch[0] + h, w = batch.shape[-2:] # Batch shape: [C, H, W] + + with torch.no_grad(): + # Generate priors for the video + if prior_data is None: + priorbox = PriorBox(self.cfg, image_size=(h, w)) + priors = priorbox.forward() + priors = priors.to(self.device, dtype=self.dtype) + prior_data = priors.data + + # Get scale for resizing detections + if scale is None: + scale = torch.Tensor([w, h, w, h]) + scale = scale.to(self.device, dtype=self.dtype) + + batch_loc, batch_conf, _ = self.net(batch) + + # Blur detected faces in each batch of frames + start_idx = i * self.batch_size + end_idx = min(start_idx + self.batch_size, len(frames)) + processed_batches.append( + self.blur_detected_faces(frames[start_idx:end_idx], batch_loc, batch_conf, prior_data, scale) + ) + + processed_frames = [frame for batch in processed_batches for frame in batch] + return np.array(processed_frames) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type=str, required=True, help="Path containing input videos") + parser.add_argument("--output_dir", type=str, required=True, help="Path for saving processed videos") + parser.add_argument( + "--checkpoint", + type=str, + help="Path to the RetinaFace checkpoint file", + default=DEFAULT_RETINAFACE_CHECKPOINT, + ) + return parser.parse_args() + + +def main(args): + filepaths = get_video_filepaths(args.input_dir) + if not filepaths: + log.error(f"No video files found in directory: {args.input_dir}") + return + + face_blur = RetinaFaceFilter(checkpoint=args.checkpoint) + postprocessing_runner = GuardrailRunner(postprocessors=[face_blur]) + os.makedirs(args.output_dir, exist_ok=True) + + for filepath in tqdm(filepaths): + video_data = 
read_video(filepath) + with timer("face blur filter"): + frames = postprocessing_runner.postprocess(video_data.frames) + + output_path = os.path.join(args.output_dir, os.path.basename(filepath)) + save_video(output_path, frames, video_data.fps) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/file_io.py b/file_io.py new file mode 100644 index 0000000000000000000000000000000000000000..d9caf0081976dd08ab6ea1c04ad53304bc51d05d --- /dev/null +++ b/file_io.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler +from iopath.common.file_io import PathManager as PathManagerBase + +__all__ = ["PathManager", "PathHandler"] + + +PathManager = PathManagerBase() +PathManager.register_handler(HTTPURLHandler()) +PathManager.register_handler(OneDrivePathHandler()) diff --git a/general_dit.py b/general_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..eafd288d441442dd5e04824ccbcbd29291a38ca1 --- /dev/null +++ b/general_dit.py @@ -0,0 +1,520 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. +""" + +from typing import List, Optional, Tuple + +import torch +from einops import rearrange +from torch import nn +from torchvision import transforms + +from .conditioner import DataType +from .attention import get_normalization +from .blocks import ( + FinalLayer, + GeneralDITTransformerBlock, + PatchEmbed, + TimestepEmbedding, + Timesteps, +) +from .position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb +from .log import log + + +class GeneralDIT(nn.Module): + """ + A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. + + Args: + max_img_h (int): Maximum height of the input images. + max_img_w (int): Maximum width of the input images. + max_frames (int): Maximum number of frames in the video sequence. + in_channels (int): Number of input channels (e.g., RGB channels for color images). + out_channels (int): Number of output channels. + patch_spatial (tuple): Spatial resolution of patches for input processing. 
+ patch_temporal (int): Temporal resolution of patches for input processing. + concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding. + block_config (str): Configuration of the transformer block. See Notes for supported block types. + model_channels (int): Base number of channels used throughout the model. + num_blocks (int): Number of transformer blocks. + num_heads (int): Number of heads in the multi-head attention layers. + mlp_ratio (float): Expansion ratio for MLP blocks. + block_x_format (str): Format of input tensor for transformer blocks ('BTHWD' or 'THWBD'). + crossattn_emb_channels (int): Number of embedding channels for cross-attention. + use_cross_attn_mask (bool): Whether to use mask in cross-attention. + pos_emb_cls (str): Type of positional embeddings. + pos_emb_learnable (bool): Whether positional embeddings are learnable. + pos_emb_interpolation (str): Method for interpolating positional embeddings. + affline_emb_norm (bool): Whether to normalize affine embeddings. + use_adaln_lora (bool): Whether to use AdaLN-LoRA. + adaln_lora_dim (int): Dimension for AdaLN-LoRA. + rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE. + rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE. + rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE. + extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings. + extra_per_block_abs_pos_emb_type (str): Type of extra per-block positional embeddings. + extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings. + extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings. + extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings. 
+ + Notes: + Supported block types in block_config: + * cross_attn, ca: Cross attention + * full_attn: Full attention on all flattened tokens + * mlp, ff: Feed forward block + """ + + def __init__( + self, + max_img_h: int, + max_img_w: int, + max_frames: int, + in_channels: int, + out_channels: int, + patch_spatial: tuple, + patch_temporal: int, + concat_padding_mask: bool = True, + # attention settings + block_config: str = "FA-CA-MLP", + model_channels: int = 768, + num_blocks: int = 10, + num_heads: int = 16, + mlp_ratio: float = 4.0, + block_x_format: str = "BTHWD", + # cross attention settings + crossattn_emb_channels: int = 1024, + use_cross_attn_mask: bool = False, + # positional embedding settings + pos_emb_cls: str = "sincos", + pos_emb_learnable: bool = False, + pos_emb_interpolation: str = "crop", + affline_emb_norm: bool = False, # whether or not to normalize the affine embedding + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + rope_h_extrapolation_ratio: float = 1.0, + rope_w_extrapolation_ratio: float = 1.0, + rope_t_extrapolation_ratio: float = 1.0, + extra_per_block_abs_pos_emb: bool = False, + extra_per_block_abs_pos_emb_type: str = "sincos", + extra_h_extrapolation_ratio: float = 1.0, + extra_w_extrapolation_ratio: float = 1.0, + extra_t_extrapolation_ratio: float = 1.0, + ) -> None: + super().__init__() + self.max_img_h = max_img_h + self.max_img_w = max_img_w + self.max_frames = max_frames + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.num_heads = num_heads + self.num_blocks = num_blocks + self.model_channels = model_channels + self.use_cross_attn_mask = use_cross_attn_mask + self.concat_padding_mask = concat_padding_mask + # positional embedding settings + self.pos_emb_cls = pos_emb_cls + self.pos_emb_learnable = pos_emb_learnable + self.pos_emb_interpolation = pos_emb_interpolation + self.affline_emb_norm = affline_emb_norm + self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio + self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio + self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio + self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb + self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower() + self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio + self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio + self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio + + self.build_patch_embed() + self.build_pos_embed() + self.block_x_format = block_x_format + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + self.t_embedder = nn.Sequential( + Timesteps(model_channels), + TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora), + ) + + self.blocks = nn.ModuleDict() + + for idx in range(num_blocks): + self.blocks[f"block{idx}"] = GeneralDITTransformerBlock( + x_dim=model_channels, + context_dim=crossattn_emb_channels, + num_heads=num_heads, + block_config=block_config, + mlp_ratio=mlp_ratio, + x_format=self.block_x_format, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + ) + + self.build_decode_head() + if self.affline_emb_norm: + log.debug("Building affine embedding normalization layer") + self.affline_norm = get_normalization("R", model_channels) + else: + self.affline_norm = nn.Identity() + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def 
_basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep embedding + nn.init.normal_(self.t_embedder[1].linear_1.weight, std=0.02) + if self.t_embedder[1].linear_1.bias is not None: + nn.init.constant_(self.t_embedder[1].linear_1.bias, 0) + nn.init.normal_(self.t_embedder[1].linear_2.weight, std=0.02) + if self.t_embedder[1].linear_2.bias is not None: + nn.init.constant_(self.t_embedder[1].linear_2.bias, 0) + + # Zero-out adaLN modulation layers in DiT blocks: + for transformer_block in self.blocks.values(): + for block in transformer_block.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + if block.adaLN_modulation[-1].bias is not None: + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + def build_decode_head(self): + self.final_layer = FinalLayer( + hidden_size=self.model_channels, + spatial_patch_size=self.patch_spatial, + temporal_patch_size=self.patch_temporal, + out_channels=self.out_channels, + use_adaln_lora=self.use_adaln_lora, + adaln_lora_dim=self.adaln_lora_dim, + ) + + def build_patch_embed(self): + ( + concat_padding_mask, + in_channels, + patch_spatial, + patch_temporal, + model_channels, + ) = ( + self.concat_padding_mask, + self.in_channels, + self.patch_spatial, + self.patch_temporal, + self.model_channels, + ) + in_channels = in_channels + 1 if concat_padding_mask else in_channels + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + bias=False, + ) + + def build_pos_embed(self): + if self.pos_emb_cls == "rope3d": + cls_type = VideoRopePosition3DEmb + else: + raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") + + log.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") + kwargs = dict( + model_channels=self.model_channels, + len_h=self.max_img_h // self.patch_spatial, + len_w=self.max_img_w // self.patch_spatial, + len_t=self.max_frames // self.patch_temporal, + is_learnable=self.pos_emb_learnable, + interpolation=self.pos_emb_interpolation, + head_dim=self.model_channels // self.num_heads, + h_extrapolation_ratio=self.rope_h_extrapolation_ratio, + w_extrapolation_ratio=self.rope_w_extrapolation_ratio, + t_extrapolation_ratio=self.rope_t_extrapolation_ratio, + ) + self.pos_embedder = cls_type( + **kwargs, + ) + + if self.extra_per_block_abs_pos_emb: + assert self.extra_per_block_abs_pos_emb_type in [ + "learnable", + ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}" + kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio + kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio + kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio + self.extra_pos_embedder = LearnablePosEmbAxis( + **kwargs, + ) + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. 
+ + Args: + x_B_C_T_H_W (torch.Tensor): video + fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. + If None, a default value (`self.base_fps`) will be used. + padding_mask (Optional[torch.Tensor]): current it is not used + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - A tensor of shape (B, T, H, W, D) with the embedded sequence. + - An optional positional embedding tensor, returned only if the positional embedding class + (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. + + Notes: + - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. + - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. + - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using + the `self.pos_embedder` with the shape [T, H, W]. + - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the + `self.pos_embedder` with the fps tensor. + - Otherwise, the positional embeddings are generated without considering fps. + """ + if self.concat_padding_mask: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb + + if "fps_aware" in self.pos_emb_cls: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] + else: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] + + return x_B_T_H_W_D, None, extra_pos_emb + + def decoder_head( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + origin_shape: Tuple[int, int, int, int, int], # [B, C, T, H, W] + crossattn_mask: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + del crossattn_emb, crossattn_mask + B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape + x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D") + x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D) + # This is to ensure x_BT_HW_D has the correct shape because + # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D). 
+ x_BT_HW_D = x_BT_HW_D.view( + B * T_before_patchify // self.patch_temporal, + H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial, + -1, + ) + x_B_D_T_H_W = rearrange( + x_BT_HW_D, + "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)", + p1=self.patch_spatial, + p2=self.patch_spatial, + H=H_before_patchify // self.patch_spatial, + W=W_before_patchify // self.patch_spatial, + t=self.patch_temporal, + B=B, + ) + return x_B_D_T_H_W + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + """ + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." + original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = 
DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to + augment condition input, the lvg model will condition on the condition_video_augment_sigma value; + we need forward_before_blocks pass to the forward_before_blocks function. + """ + + inputs = self.forward_before_blocks( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) + x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = ( + inputs["x"], + inputs["affline_emb_B_D"], + inputs["crossattn_emb"], + inputs["crossattn_mask"], + inputs["rope_emb_L_1_1_D"], + inputs["adaln_lora_B_3D"], + inputs["original_shape"], + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"] + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + assert ( + x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape + ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" + + for _, block in self.blocks.items(): + assert ( + self.blocks["block0"].x_format == block.x_format + ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}" + + x = block( + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + ) + + x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") + + x_B_D_T_H_W = self.decoder_head( + x_B_T_H_W_D=x_B_T_H_W_D, + emb_B_D=affline_emb_B_D, + crossattn_emb=None, + origin_shape=original_shape, + crossattn_mask=None, + adaln_lora_B_3D=adaln_lora_B_3D, + ) + + return x_B_D_T_H_W diff --git a/general_dit_video_conditioned.py b/general_dit_video_conditioned.py new file mode 100644 index 0000000000000000000000000000000000000000..15fbc5d71ae8b8deb8aa8d1921a0632c8e7d689f --- /dev/null +++ b/general_dit_video_conditioned.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
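The final `rearrange` in `GeneralDIT.decoder_head` packs a lot of shape bookkeeping into one pattern. The following self-contained sketch (toy sizes chosen for readability, not taken from any config) reproduces the same unpatchify step, making the mapping from token dimension `D = p1 * p2 * t * C` back to a `(B, C, T, H, W)` video explicit:

```python
# Toy-sized reproduction of the unpatchify pattern used in decoder_head.
import torch
from einops import rearrange

B, T, H, W = 1, 2, 4, 4        # token-grid sizes after patchification
p1 = p2 = 2                    # patch_spatial
t = 1                          # patch_temporal
C = 16                         # out_channels
x_BT_HW_D = torch.randn(B * T, H * W, p1 * p2 * t * C)

x_B_C_T_H_W = rearrange(
    x_BT_HW_D,
    "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
    B=B, H=H, W=W, p1=p1, p2=p2, t=t,
)
print(x_B_C_T_H_W.shape)  # torch.Size([1, 16, 2, 8, 8])
```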
+ +from typing import Optional + +import torch +from einops import rearrange +from torch import nn + +from .conditioner import DataType +from .blocks import TimestepEmbedding, Timesteps +from .general_dit import GeneralDIT +from .log import log + + +class VideoExtendGeneralDIT(GeneralDIT): + def __init__(self, *args, in_channels=16 + 1, add_augment_sigma_embedding=False, **kwargs): + self.add_augment_sigma_embedding = add_augment_sigma_embedding + + # extra channel for video condition mask + super().__init__(*args, in_channels=in_channels, **kwargs) + log.debug(f"VideoExtendGeneralDIT in_channels: {in_channels}") + + def build_additional_timestamp_embedder(self): + super().build_additional_timestamp_embedder() + if self.add_augment_sigma_embedding: + log.info("Adding augment sigma embedding") + self.augment_sigma_embedder = nn.Sequential( + Timesteps(self.model_channels), + TimestepEmbedding(self.model_channels, self.model_channels, use_adaln_lora=self.use_adaln_lora), + ) + + def initialize_weights(self): + if self.add_augment_sigma_embedding: + # Initialize timestep embedding for augment sigma + nn.init.normal_(self.augment_sigma_embedder[1].linear_1.weight, std=0.02) + if self.augment_sigma_embedder[1].linear_1.bias is not None: + nn.init.constant_(self.augment_sigma_embedder[1].linear_1.bias, 0) + nn.init.normal_(self.augment_sigma_embedder[1].linear_2.weight, std=0.02) + if self.augment_sigma_embedder[1].linear_2.bias is not None: + nn.init.constant_(self.augment_sigma_embedder[1].linear_2.bias, 0) + + super().initialize_weights() # Call this last since it wil call TP weight init + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + video_cond_bool: Optional[torch.Tensor] = None, + condition_video_indicator: Optional[torch.Tensor] = None, + condition_video_input_mask: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """Forward pass of the video-conditioned DIT model. 
+ + Args: + x: Input tensor of shape (B, C, T, H, W) + timesteps: Timestep tensor of shape (B,) + crossattn_emb: Cross attention embeddings of shape (B, N, D) + crossattn_mask: Optional cross attention mask of shape (B, N) + fps: Optional frames per second tensor + image_size: Optional image size tensor + padding_mask: Optional padding mask tensor + scalar_feature: Optional scalar features tensor + data_type: Type of data being processed (default: DataType.VIDEO) + video_cond_bool: Optional video conditioning boolean tensor + condition_video_indicator: Optional video condition indicator tensor + condition_video_input_mask: Required mask tensor for video data type + condition_video_augment_sigma: Optional sigma values for conditional input augmentation + **kwargs: Additional keyword arguments + + Returns: + torch.Tensor: Output tensor + """ + B, C, T, H, W = x.shape + + if data_type == DataType.VIDEO: + assert condition_video_input_mask is not None, "condition_video_input_mask is required for video data type" + + input_list = [x, condition_video_input_mask] + x = torch.cat( + input_list, + dim=1, + ) + + return super().forward( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + + condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation + """ + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
+ original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + ) + # logging affline scale information + affline_scale_log.info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log.info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + + if self.add_augment_sigma_embedding: + if condition_video_augment_sigma is None: + # Handling image case + # Note: for video case, when there is not condition frames, we also set it as zero, see extend_model augment_conditional_latent_frames function + assert data_type == DataType.IMAGE, "condition_video_augment_sigma is required for video data type" + condition_video_augment_sigma = torch.zeros_like(timesteps.flatten()) + + affline_augment_sigma_emb_B_D, _ = self.augment_sigma_embedder(condition_video_augment_sigma.flatten()) + affline_emb_B_D = affline_emb_B_D + affline_augment_sigma_emb_B_D + affline_scale_log.info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output diff --git a/guardrail_blocklist_utils.py b/guardrail_blocklist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..859eb6498143e5b063dbc888dca7748a07cfda9d --- /dev/null +++ b/guardrail_blocklist_utils.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
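`VideoExtendGeneralDIT` defaults to `in_channels=16 + 1` because the latent video is concatenated with a one-channel conditioning mask before it reaches the patch embedder. A minimal, self-contained sketch of that widening step (toy shapes; an all-ones tensor stands in for `condition_video_input_mask`):

```python
# Sketch of the channel concatenation in VideoExtendGeneralDIT.forward:
# latent video (16 channels) + conditioning mask (1 channel) -> 17 channels.
import torch

B, C, T, H, W = 1, 16, 8, 32, 32
x = torch.randn(B, C, T, H, W)                          # latent video
condition_video_input_mask = torch.ones(B, 1, T, H, W)  # marks conditioning content

x_with_mask = torch.cat([x, condition_video_input_mask], dim=1)
print(x_with_mask.shape)  # torch.Size([1, 17, 8, 32, 32])
```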
+ +import os +import re + +from .log import log + + +def read_keyword_list_from_dir(folder_path: str) -> list[str]: + """Read keyword list from all files in a folder.""" + output_list = [] + file_list = [] + # Get list of files in the folder + for file in os.listdir(folder_path): + if os.path.isfile(os.path.join(folder_path, file)): + file_list.append(file) + + # Process each file + for file in file_list: + file_path = os.path.join(folder_path, file) + try: + with open(file_path, "r") as f: + output_list.extend([line.strip() for line in f.readlines()]) + except Exception as e: + log.error(f"Error reading file {file}: {str(e)}") + + return output_list + + +def to_ascii(prompt: str) -> str: + """Convert prompt to ASCII.""" + return re.sub(r"[^\x00-\x7F]+", " ", prompt) diff --git a/guardrail_core.py b/guardrail_core.py new file mode 100644 index 0000000000000000000000000000000000000000..e4916c3379353f577a811def9a1d29f2e0a48708 --- /dev/null +++ b/guardrail_core.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Tuple + +import numpy as np + +from .log import log + + +class ContentSafetyGuardrail: + def is_safe(self, **kwargs) -> Tuple[bool, str]: + raise NotImplementedError("Child classes must implement the is_safe method") + + +class PostprocessingGuardrail: + def postprocess(self, frames: np.ndarray) -> np.ndarray: + raise NotImplementedError("Child classes must implement the postprocess method") + + +class GuardrailRunner: + def __init__( + self, + safety_models: list[ContentSafetyGuardrail] | None = None, + generic_block_msg: str = "", + generic_safe_msg: str = "", + postprocessors: list[PostprocessingGuardrail] | None = None, + ): + self.safety_models = safety_models + self.generic_block_msg = generic_block_msg + self.generic_safe_msg = generic_safe_msg if generic_safe_msg else "Prompt is safe" + self.postprocessors = postprocessors + + def run_safety_check(self, input: Any) -> Tuple[bool, str]: + """Run the safety check on the input.""" + if not self.safety_models: + log.warning("No safety models found, returning safe") + return True, self.generic_safe_msg + + for guardrail in self.safety_models: + guardrail_name = str(guardrail.__class__.__name__).upper() + log.debug(f"Running guardrail: {guardrail_name}") + safe, message = guardrail.is_safe(input) + if not safe: + reasoning = self.generic_block_msg if self.generic_block_msg else f"{guardrail_name}: {message}" + return False, reasoning + return True, self.generic_safe_msg + + def postprocess(self, frames: np.ndarray) -> np.ndarray: + """Run the postprocessing on the video frames.""" + if not self.postprocessors: + log.warning("No postprocessors found, returning original frames") + return frames + + for guardrail in self.postprocessors: + guardrail_name = str(guardrail.__class__.__name__).upper() + log.debug(f"Running guardrail: {guardrail_name}") + frames = 
guardrail.postprocess(frames) + return frames diff --git a/guardrail_io_utils.py b/guardrail_io_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..148897d5cae9165673cb74e336548c71adb261b1 --- /dev/null +++ b/guardrail_io_utils.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +from dataclasses import dataclass + +import imageio +import numpy as np + +from .log import log + + +@dataclass +class VideoData: + frames: np.ndarray # Shape: [B, H, W, C] + fps: int + duration: int # in seconds + + +def get_video_filepaths(input_dir: str) -> list[str]: + """Get a list of filepaths for all videos in the input directory.""" + paths = glob.glob(f"{input_dir}/**/*.mp4", recursive=True) + paths += glob.glob(f"{input_dir}/**/*.avi", recursive=True) + paths += glob.glob(f"{input_dir}/**/*.mov", recursive=True) + paths = sorted(paths) + log.debug(f"Found {len(paths)} videos") + return paths + + +def read_video(filepath: str) -> VideoData: + """Read a video file and extract its frames and metadata.""" + try: + reader = imageio.get_reader(filepath, "ffmpeg") + except Exception as e: + raise ValueError(f"Failed to read video file: {filepath}") from e + + # Extract metadata from the video file + try: + metadata = reader.get_meta_data() + fps = metadata.get("fps") + duration = metadata.get("duration") + except Exception as e: + reader.close() + raise ValueError(f"Failed to extract metadata from video file: {filepath}") from e + + # Extract frames from the video file + try: + frames = np.array([frame for frame in reader]) + except Exception as e: + raise ValueError(f"Failed to extract frames from video file: {filepath}") from e + finally: + reader.close() + + return VideoData(frames=frames, fps=fps, duration=duration) + + +def save_video(filepath: str, frames: np.ndarray, fps: int) -> None: + """Save a video file from a sequence of frames.""" + try: + writer = imageio.get_writer(filepath, fps=fps, macro_block_size=1) + for frame in frames: + writer.append_data(frame) + except Exception as e: + raise ValueError(f"Failed to save video file to {filepath}") from e + finally: + writer.close() diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..4252b3758a2d49c5809dd963f7ae403209cbff7b --- /dev/null +++ b/inference.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, TypedDict + +import torch + +from .ar_model import AutoRegressiveModel +from .ar_tokenizer_image_text_tokenizer import ImageTextTokenizer +from .ar_tokenizer_text_tokenizer import TextTokenizer + + +class ChatPrediction(TypedDict, total=False): + tokens: List[str] # not required + logprobs: List[float] # not required + + +def chat_completion( + model: AutoRegressiveModel, + dialogs: List, + seed: int = None, + temperature: float = 0.01, + top_k: int = None, + top_p: float = None, + max_gen_len: Optional[int] = None, + num_gen_seq: int = 1, + logprobs: bool = False, + generation_prefix: str = "", + compile_sampling: bool = False, + compile_prefill: bool = False, + stop_tokens=None, + verbose: bool = False, +) -> List[ChatPrediction]: + """ + Generate assistant responses for a list of conversational dialogs using the language generation model. + + Args: + model (AutoRegressiveModel): The language generation model. + dialogs (List): List of conversational dialogs, where each dialog is a list of messages. + NOTE if you are using a VLM, all dialogs must either all have images ("image" field) or all be pure text. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.01. + top_k (int, optional): Top-k probability threshold for nucleus sampling. Defaults to None. If not None, top-p sampling is ignored. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to None. If not None, top-k sampling is ignored. + max_gen_len (Optional[int], optional): Maximum length of the generated response sequence. + If not provided, it's set to the model's maximum sequence length minus 1. + num_gen_seq (int, optional): Number of sequences to generate per prompt. Defaults to 1. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + generation_prefix (str, optional): Prefix to add before asking model to generate. Helpful to guide the generation. Defaults to "". + compile_sampling (bool, optional): Flag indicating whether to compile the generation function. Defaults to False. + compile_prefill (bool, optional): Flag indicating whether to compile the prefill function. Defaults to False. + stop_tokens (Set[int], optional): Set of tokens to stop generation. Defaults to None. If not None, it will override the model's stop tokens. + verbose (bool, optional): Flag indicating whether to print the generation throughput. Defaults to False. + Returns: + List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response. + + Note: + This method generates assistant responses for the provided conversational dialogs. + It employs nucleus sampling to introduce controlled randomness in text generation. + If logprobs is True, token log probabilities are computed for each generated token. 
+ """ + if max_gen_len is None: + max_gen_len = model.model.params.max_seq_len - 1 + images = None + if isinstance(model.tokenizer.text_tokenizer, ImageTextTokenizer): + # Vision-language model + prompt_dicts = [ + model.tokenizer.text_tokenizer.apply_chat_template( + dialog, generation_prefix=generation_prefix, add_generation_prompt=True + ) + for dialog in dialogs + ] + prompt_tokens = [prompt_dict["input_ids"] for prompt_dict in prompt_dicts] + num_images = sum(["pixel_values" in prompt_dict for prompt_dict in prompt_dicts]) + assert num_images in [0, len(dialogs)], "For VLM, all dialogs must either all have images or all be pure text." + if num_images > 0: + images = torch.cat([prompt_dict["pixel_values"] for prompt_dict in prompt_dicts], dim=0) + else: + images = None + elif isinstance(model.tokenizer.text_tokenizer, TextTokenizer): + # Text-only model + prompt_tokens = [ + model.tokenizer.text_tokenizer.apply_chat_template( + dialog, generation_prefix=generation_prefix, add_generation_prompt=True + ) + for dialog in dialogs + ] + else: + prompt_tokens = [model.formatter.encode_dialog_prompt(dialog) for dialog in dialogs] + + generation_tokens, generation_logprobs = model.generate( + prompt_tokens=prompt_tokens, + seed=seed, + max_gen_len=max_gen_len, + num_gen_seq=num_gen_seq, + temperature=temperature, + top_k=top_k, + top_p=top_p, + compile_sampling=compile_sampling, + compile_prefill=compile_prefill, + stop_tokens=stop_tokens, + verbose=verbose, + images=images, + ) + + if logprobs: + return [ + { + "generation": { + "role": "assistant", + "content": model.tokenizer.text_tokenizer.decode(t), + }, + "tokens": [model.tokenizer.text_tokenizer.decode([x]) for x in t], + "logprobs": logprobs_i, + } + for t, logprobs_i in zip(generation_tokens, generation_logprobs) + ] + return [ + { + "generation": { + "role": "assistant", + "content": model.tokenizer.text_tokenizer.decode(t), + }, + } + for t in generation_tokens + ] diff --git a/inference_config.py b/inference_config.py new file mode 100644 index 0000000000000000000000000000000000000000..5ea7782d3a7217d3c5eaba9d696b0a6dc3f836ed --- /dev/null +++ b/inference_config.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Union + +import attrs + +from .ar_configs_base_model import ModelConfig, TokenizerConfig + + +@attrs.define(slots=False) +class DataShapeConfig: + latent_shape: list = [] + num_video_frames: Union[None, int] = None + height: Union[None, int] = None + width: Union[None, int] = None + + +@attrs.define(slots=False) +class SamplingConfig: + """ + Sampling config + Args: + temperature (float): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + logprobs (bool): Flag indicating whether to compute token log probabilities. Defaults to False. 
+ echo (bool): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + + """ + + temperature: float = 0.6 + top_k: int = None + top_p: float = 0.9 + compile_prefill: bool = False + compile_sampling: bool = True + logprobs: bool = False + echo: bool = False + + +@attrs.define(slots=False) +class DiffusionDecoderSamplingConfig: + """ + Diffusion decoder sampling config + Args: + guidance (float): Guidance scale for the diffusion process. Controls how much the model follows the conditioning. Defaults to 0.8. + sigma_min (float): Minimum noise level for the diffusion process. Defaults to 0.02. + sigma (float): Initial noise level for the diffusion process. Defaults to 8. + num_steps (int): Number of denoising steps to perform. Defaults to 35. + overlap (int): Number of overlapping frames between video chunks during processing. Defaults to 2. + continuous_tokenizer_channel (int): Number of channels in the continuous tokenizer of diffusion decoder. Defaults to 16. + continuous_tokenizer_spatial_compression_ratio (int): Spatial compression ratio for the continuous tokenizer of diffusion decoder. Defaults to 8. + dd_train_num_video_frames (int): Number of video frames used during training for diffusion decoder. Defaults to 57. + """ + + guidance: float = 1.8 + sigma_min: float = 0.02 + sigma: float = 8 + num_steps: int = 15 + overlap: int = 2 + continuous_tokenizer_channel = 16 + continuous_tokenizer_spatial_compression_ratio = 8 + dd_train_num_video_frames: int = 57 + max_iter: int = 99 + fps: int = 24 + + +@attrs.define(slots=False) +class InferenceConfig: + """ + Inference config + Args: + model_config (ModelConfig): Model config + tokenizer_config (TokenizerConfig): Tokenizer config + ckpt_path (str): Path to the checkpoint + latent_shape (list): Shape of the latent + """ + + model_config: ModelConfig = None + tokenizer_config: TokenizerConfig = None + ckpt_path: str = "" + data_shape_config: DataShapeConfig = None + + defaults: List[Any] = attrs.field( + factory=lambda: [ + "_self_", + {"data_val": None}, + {"data_shape_config": "video_shape_as_model_config"}, + {"eval_job": None}, + ] + ) diff --git a/inference_utils.py b/inference_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d81404d99534c509d7ad93fbfb300f34721380dc --- /dev/null +++ b/inference_utils.py @@ -0,0 +1,732 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
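The configs in `inference_config.py` are plain `attrs` classes, so overriding a default is just a keyword argument to the generated constructor. A trimmed-down, self-contained copy of `SamplingConfig` (only a subset of its fields, reproduced here purely for illustration) shows the pattern:

```python
# Trimmed illustration of the attrs-based config pattern used above.
import attrs


@attrs.define(slots=False)
class SamplingConfig:
    temperature: float = 0.6
    top_k: int = None
    top_p: float = 0.9
    logprobs: bool = False


cfg = SamplingConfig(temperature=0.9, top_p=0.95)
print(cfg)  # SamplingConfig(temperature=0.9, top_k=None, top_p=0.95, logprobs=False)
```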
+ +import argparse +import importlib +from contextlib import contextmanager +from typing import List, NamedTuple, Optional, Tuple + +import einops +import imageio +import numpy as np +import torch +import torchvision.transforms.functional as transforms_F + +from .model_t2w import DiffusionT2WModel +from .model_v2w import DiffusionV2WModel +from .config_helper import get_config_module, override +from .utils_io import load_from_fileobj +from .misc import misc +from .df_config_config import make_config +from .log import log + +TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2]) +if TORCH_VERSION >= (1, 11): + from torch.ao import quantization + from torch.ao.quantization import FakeQuantizeBase, ObserverBase +elif ( + TORCH_VERSION >= (1, 8) + and hasattr(torch.quantization, "FakeQuantizeBase") + and hasattr(torch.quantization, "ObserverBase") +): + from torch import quantization + from torch.quantization import FakeQuantizeBase, ObserverBase + +DEFAULT_AUGMENT_SIGMA = 0.001 + + +def add_common_arguments(parser): + """Add common command line arguments for text2world and video2world generation. + + Args: + parser (ArgumentParser): Argument parser to add arguments to + + The arguments include: + - checkpoint_dir: Base directory containing model weights + - tokenizer_dir: Directory containing tokenizer weights + - video_save_name: Output video filename for single video generation + - video_save_folder: Output directory for batch video generation + - prompt: Text prompt for single video generation + - batch_input_path: Path to JSONL file with input prompts for batch video generation + - negative_prompt: Text prompt describing undesired attributes + - num_steps: Number of diffusion sampling steps + - guidance: Classifier-free guidance scale + - num_video_frames: Number of frames to generate + - height/width: Output video dimensions + - fps: Output video frame rate + - seed: Random seed for reproducibility + - Various model offloading flags + """ + parser.add_argument( + "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints" + ) + parser.add_argument( + "--tokenizer_dir", + type=str, + default="Cosmos-1.0-Tokenizer-CV8x8x8", + help="Tokenizer weights directory relative to checkpoint_dir", + ) + parser.add_argument( + "--video_save_name", + type=str, + default="output", + help="Output filename for generating a single video", + ) + parser.add_argument( + "--video_save_folder", + type=str, + default="outputs/", + help="Output folder for generating a batch of videos", + ) + parser.add_argument( + "--prompt", + type=str, + help="Text prompt for generating a single video", + ) + parser.add_argument( + "--batch_input_path", + type=str, + help="Path to a JSONL file of input prompts for generating a batch of videos", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default="The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, " + "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special " + "effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and " + "flickering. 
Overall, the video is of poor quality.", + help="Negative prompt for the video", + ) + parser.add_argument("--num_steps", type=int, default=35, help="Number of diffusion sampling steps") + parser.add_argument("--guidance", type=float, default=7, help="Guidance scale value") + parser.add_argument("--num_video_frames", type=int, default=121, help="Number of video frames to sample") + parser.add_argument("--height", type=int, default=704, help="Height of video to sample") + parser.add_argument("--width", type=int, default=1280, help="Width of video to sample") + parser.add_argument("--fps", type=int, default=24, help="FPS of the sampled video") + parser.add_argument("--seed", type=int, default=1, help="Random seed") + parser.add_argument( + "--disable_prompt_upsampler", + action="store_true", + help="Disable prompt upsampling", + ) + parser.add_argument( + "--offload_diffusion_transformer", + action="store_true", + help="Offload DiT after inference", + ) + parser.add_argument( + "--offload_tokenizer", + action="store_true", + help="Offload tokenizer after inference", + ) + parser.add_argument( + "--offload_text_encoder_model", + action="store_true", + help="Offload text encoder model after inference", + ) + parser.add_argument( + "--offload_prompt_upsampler", + action="store_true", + help="Offload prompt upsampler after inference", + ) + parser.add_argument( + "--offload_guardrail_models", + action="store_true", + help="Offload guardrail models after inference", + ) + + +def validate_args(args: argparse.Namespace, inference_type: str) -> None: + """Validate command line arguments for text2world and video2world generation.""" + assert inference_type in [ + "text2world", + "video2world", + ], "Invalid inference_type, must be 'text2world' or 'video2world'" + + # Validate prompt/image/video args for single or batch generation + if inference_type == "text2world" or (inference_type == "video2world" and args.disable_prompt_upsampler): + assert args.prompt or args.batch_input_path, "--prompt or --batch_input_path must be provided." + if inference_type == "video2world" and not args.batch_input_path: + assert ( + args.input_image_or_video_path + ), "--input_image_or_video_path must be provided for single video generation." + + +class _IncompatibleKeys( + NamedTuple( + "IncompatibleKeys", + [ + ("missing_keys", List[str]), + ("unexpected_keys", List[str]), + ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]), + ], + ) +): + pass + + +def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys: + """Load a model checkpoint with non-strict matching, handling shape mismatches. 
+ + Args: + model (torch.nn.Module): Model to load weights into + checkpoint_state_dict (dict): State dict from checkpoint + + Returns: + _IncompatibleKeys: Named tuple containing: + - missing_keys: Keys present in model but missing from checkpoint + - unexpected_keys: Keys present in checkpoint but not in model + - incorrect_shapes: Keys with mismatched tensor shapes + + The function handles special cases like: + - Uninitialized parameters + - Quantization observers + - TransformerEngine FP8 states + """ + # workaround https://github.com/pytorch/pytorch/issues/24139 + model_state_dict = model.state_dict() + incorrect_shapes = [] + for k in list(checkpoint_state_dict.keys()): + if k in model_state_dict: + if "_extra_state" in k: # Key introduced by TransformerEngine for FP8 + log.debug(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.") + continue + model_param = model_state_dict[k] + # Allow mismatch for uninitialized parameters + if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter): + continue + if not isinstance(model_param, torch.Tensor): + raise ValueError( + f"Find non-tensor parameter {k} in the model. type: {type(model_param)} {type(checkpoint_state_dict[k])}, please check if this key is safe to skip or not." + ) + + shape_model = tuple(model_param.shape) + shape_checkpoint = tuple(checkpoint_state_dict[k].shape) + if shape_model != shape_checkpoint: + has_observer_base_classes = ( + TORCH_VERSION >= (1, 8) + and hasattr(quantization, "ObserverBase") + and hasattr(quantization, "FakeQuantizeBase") + ) + if has_observer_base_classes: + # Handle the special case of quantization per channel observers, + # where buffer shape mismatches are expected. + def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module: + # foo.bar.param_or_buffer_name -> [foo, bar] + key_parts = key.split(".")[:-1] + cur_module = model + for key_part in key_parts: + cur_module = getattr(cur_module, key_part) + return cur_module + + cls_to_skip = ( + ObserverBase, + FakeQuantizeBase, + ) + target_module = _get_module_for_key(model, k) + if isinstance(target_module, cls_to_skip): + # Do not remove modules with expected shape mismatches + # them from the state_dict loading. They have special logic + # in _load_from_state_dict to handle the mismatches. 
+ continue + + incorrect_shapes.append((k, shape_checkpoint, shape_model)) + checkpoint_state_dict.pop(k) + incompatible = model.load_state_dict(checkpoint_state_dict, strict=False) + # Remove keys with "_extra_state" suffix, which are non-parameter items introduced by TransformerEngine for FP8 handling + missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k] + unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k] + return _IncompatibleKeys( + missing_keys=missing_keys, + unexpected_keys=unexpected_keys, + incorrect_shapes=incorrect_shapes, + ) + + +@contextmanager +def skip_init_linear(): + # skip init of nn.Linear + orig_reset_parameters = torch.nn.Linear.reset_parameters + torch.nn.Linear.reset_parameters = lambda x: x + xavier_uniform_ = torch.nn.init.xavier_uniform_ + torch.nn.init.xavier_uniform_ = lambda x: x + yield + torch.nn.Linear.reset_parameters = orig_reset_parameters + torch.nn.init.xavier_uniform_ = xavier_uniform_ + + +def load_model_by_config( + config_job_name, + config_file="projects/cosmos_video/config/config.py", + model_class=DiffusionT2WModel, +): + # TODO: We need to modify this for huggingface because the config file path is different + # config_module = get_config_module(config_file) # cosmos1/models/diffusion/config/config.py + # config = importlib.import_module(config_module).make_config() + if model_class in (DiffusionT2WModel, DiffusionV2WModel): + config = make_config() + else: + raise NotImplementedError("TODO: didn't implement autoregression") + + config = override(config, ["--", f"experiment={config_job_name}"]) + + # Check that the config is valid + config.validate() + # Freeze the config so developers don't change it during training. + config.freeze() # type: ignore + + # Initialize model + with skip_init_linear(): + model = model_class(config.model) + return model + + +def load_network_model(model: DiffusionT2WModel, ckpt_path: str): + with skip_init_linear(): + model.set_up_model() + net_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) + log.debug(non_strict_load_model(model.model, net_state_dict)) + model.cuda() + + +def load_tokenizer_model(model: DiffusionT2WModel, tokenizer_dir: str): + with skip_init_linear(): + model.set_up_tokenizer(tokenizer_dir) + model.cuda() + + +def prepare_data_batch( + height: int, + width: int, + num_frames: int, + fps: int, + prompt_embedding: torch.Tensor, + negative_prompt_embedding: Optional[torch.Tensor] = None, +): + """Prepare input batch tensors for video generation. 
+ + Args: + height (int): Height of video frames + width (int): Width of video frames + num_frames (int): Number of frames to generate + fps (int): Frames per second + prompt_embedding (torch.Tensor): Encoded text prompt embeddings + negative_prompt_embedding (torch.Tensor, optional): Encoded negative prompt embeddings + + Returns: + dict: Batch dictionary containing: + - video: Zero tensor of target video shape + - t5_text_mask: Attention mask for text embeddings + - image_size: Target frame dimensions + - fps: Target frame rate + - num_frames: Number of frames + - padding_mask: Frame padding mask + - t5_text_embeddings: Prompt embeddings + - neg_t5_text_embeddings: Negative prompt embeddings (if provided) + - neg_t5_text_mask: Mask for negative embeddings (if provided) + """ + # Create base data batch + data_batch = { + "video": torch.zeros((1, 3, num_frames, height, width), dtype=torch.uint8).cuda(), + "t5_text_mask": torch.ones(1, 512, dtype=torch.bfloat16).cuda(), + "image_size": torch.tensor([[height, width, height, width]] * 1, dtype=torch.bfloat16).cuda(), + "fps": torch.tensor([fps] * 1, dtype=torch.bfloat16).cuda(), + "num_frames": torch.tensor([num_frames] * 1, dtype=torch.bfloat16).cuda(), + "padding_mask": torch.zeros((1, 1, height, width), dtype=torch.bfloat16).cuda(), + } + + # Handle text embeddings + + t5_embed = prompt_embedding.to(dtype=torch.bfloat16).cuda() + data_batch["t5_text_embeddings"] = t5_embed + + if negative_prompt_embedding is not None: + neg_t5_embed = negative_prompt_embedding.to(dtype=torch.bfloat16).cuda() + data_batch["neg_t5_text_embeddings"] = neg_t5_embed + data_batch["neg_t5_text_mask"] = torch.ones(1, 512, dtype=torch.bfloat16).cuda() + + return data_batch + + +def get_video_batch(model, prompt_embedding, negative_prompt_embedding, height, width, fps, num_video_frames): + """Prepare complete input batch for video generation including latent dimensions. + + Args: + model: Diffusion model instance + prompt_embedding (torch.Tensor): Text prompt embeddings + negative_prompt_embedding (torch.Tensor): Negative prompt embeddings + height (int): Output video height + width (int): Output video width + fps (int): Output video frame rate + num_video_frames (int): Number of frames to generate + + Returns: + tuple: + - data_batch (dict): Complete model input batch + - state_shape (list): Shape of latent state [C,T,H,W] accounting for VAE compression + """ + raw_video_batch = prepare_data_batch( + height=height, + width=width, + num_frames=num_video_frames, + fps=fps, + prompt_embedding=prompt_embedding, + negative_prompt_embedding=negative_prompt_embedding, + ) + state_shape = [ + model.tokenizer.channel, + model.tokenizer.get_latent_num_frames(num_video_frames), + height // model.tokenizer.spatial_compression_factor, + width // model.tokenizer.spatial_compression_factor, + ] + return raw_video_batch, state_shape + + +def generate_world_from_text( + model: DiffusionT2WModel, + state_shape: list[int], + is_negative_prompt: bool, + data_batch: dict, + guidance: float, + num_steps: int, + seed: int, +): + """Generate video from text prompt using diffusion model. 
+ + Args: + model (DiffusionT2WModel): Text-to-video diffusion model + state_shape (list[int]): Latent state dimensions [C,T,H,W] + is_negative_prompt (bool): Whether negative prompt is provided + data_batch (dict): Model input batch with embeddings + guidance (float): Classifier-free guidance scale + num_steps (int): Number of diffusion sampling steps + seed (int): Random seed for reproducibility + + Returns: + np.ndarray: Generated video frames [T,H,W,C], range [0,255] + + The function: + 1. Initializes random latent with maximum noise + 2. Performs guided diffusion sampling + 3. Decodes latents to pixel space + """ + x_sigma_max = ( + misc.arch_invariant_rand( + (1,) + tuple(state_shape), + torch.float32, + model.tensor_kwargs["device"], + seed, + ) + * model.sde.sigma_max + ) + + # Generate video + sample = model.generate_samples_from_batch( + data_batch, + guidance=guidance, + state_shape=state_shape, + num_steps=num_steps, + is_negative_prompt=is_negative_prompt, + seed=seed, + x_sigma_max=x_sigma_max, + ) + + return sample + + +def generate_world_from_video( + model: DiffusionV2WModel, + state_shape: list[int], + is_negative_prompt: bool, + data_batch: dict, + guidance: float, + num_steps: int, + seed: int, + condition_latent: torch.Tensor, + num_input_frames: int, +) -> Tuple[np.array, list, list]: + """Generate video using a conditioning video/image input. + + Args: + model (DiffusionV2WModel): The diffusion model instance + state_shape (list[int]): Shape of the latent state [C,T,H,W] + is_negative_prompt (bool): Whether negative prompt is provided + data_batch (dict): Batch containing model inputs including text embeddings + guidance (float): Classifier-free guidance scale for sampling + num_steps (int): Number of diffusion sampling steps + seed (int): Random seed for generation + condition_latent (torch.Tensor): Latent tensor from conditioning video/image file + num_input_frames (int): Number of input frames + + Returns: + np.array: Generated video frames in shape [T,H,W,C], range [0,255] + """ + assert not model.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i, "not supported" + augment_sigma = DEFAULT_AUGMENT_SIGMA + + if condition_latent.shape[2] < state_shape[1]: + # Padding condition latent to state shape + b, c, t, h, w = condition_latent.shape + condition_latent = torch.cat( + [ + condition_latent, + condition_latent.new_zeros(b, c, state_shape[1] - t, h, w), + ], + dim=2, + ).contiguous() + num_of_latent_condition = compute_num_latent_frames(model, num_input_frames) + + x_sigma_max = ( + misc.arch_invariant_rand( + (1,) + tuple(state_shape), + torch.float32, + model.tensor_kwargs["device"], + seed, + ) + * model.sde.sigma_max + ) + + sample = model.generate_samples_from_batch( + data_batch, + guidance=guidance, + state_shape=state_shape, + num_steps=num_steps, + is_negative_prompt=is_negative_prompt, + seed=seed, + condition_latent=condition_latent, + num_condition_t=num_of_latent_condition, + condition_video_augment_sigma_in_inference=augment_sigma, + x_sigma_max=x_sigma_max, + ) + return sample + + +def read_video_or_image_into_frames_BCTHW( + input_path: str, + input_path_format: str = "mp4", + H: int = None, + W: int = None, + normalize: bool = True, + max_frames: int = -1, + also_return_fps: bool = False, +) -> torch.Tensor: + """Read video or image file and convert to tensor format. 
+ + Args: + input_path (str): Path to input video/image file + input_path_format (str): Format of input file (default: "mp4") + H (int, optional): Height to resize frames to + W (int, optional): Width to resize frames to + normalize (bool): Whether to normalize pixel values to [-1,1] (default: True) + max_frames (int): Maximum number of frames to read (-1 for all frames) + also_return_fps (bool): Whether to return fps along with frames + + Returns: + torch.Tensor | tuple: Video tensor in shape [B,C,T,H,W], optionally with fps if requested + """ + log.debug(f"Reading video from {input_path}") + + loaded_data = load_from_fileobj(input_path, format=input_path_format) + frames, meta_data = loaded_data + if input_path.endswith(".png") or input_path.endswith(".jpg") or input_path.endswith(".jpeg"): + frames = np.array(frames[0]) # HWC, [0,255] + if frames.shape[-1] > 3: # RGBA, set the transparent to white + # Separate the RGB and Alpha channels + rgb_channels = frames[..., :3] + alpha_channel = frames[..., 3] / 255.0 # Normalize alpha channel to [0, 1] + + # Create a white background + white_bg = np.ones_like(rgb_channels) * 255 # White background in RGB + + # Blend the RGB channels with the white background based on the alpha channel + frames = (rgb_channels * alpha_channel[..., None] + white_bg * (1 - alpha_channel[..., None])).astype( + np.uint8 + ) + frames = [frames] + fps = 0 + else: + fps = int(meta_data.get("fps")) + if max_frames != -1: + frames = frames[:max_frames] + input_tensor = np.stack(frames, axis=0) + input_tensor = einops.rearrange(input_tensor, "t h w c -> t c h w") + if normalize: + input_tensor = input_tensor / 128.0 - 1.0 + input_tensor = torch.from_numpy(input_tensor).bfloat16() # TCHW + log.debug(f"Raw data shape: {input_tensor.shape}") + if H is not None and W is not None: + input_tensor = transforms_F.resize( + input_tensor, + size=(H, W), # type: ignore + interpolation=transforms_F.InterpolationMode.BICUBIC, + antialias=True, + ) + input_tensor = einops.rearrange(input_tensor, "(b t) c h w -> b c t h w", b=1) + if normalize: + input_tensor = input_tensor.to("cuda") + log.debug(f"Load shape {input_tensor.shape} value {input_tensor.min()}, {input_tensor.max()}") + if also_return_fps: + return input_tensor, fps + return input_tensor + + +def compute_num_latent_frames(model: DiffusionV2WModel, num_input_frames: int, downsample_factor=8) -> int: + """This function computes the number of latent frames given the number of input frames. 
+ Args: + model (DiffusionV2WModel): video generation model + num_input_frames (int): number of input frames + downsample_factor (int): downsample factor for temporal reduce + Returns: + int: number of latent frames + """ + num_latent_frames = ( + num_input_frames + // model.tokenizer.video_vae.pixel_chunk_duration + * model.tokenizer.video_vae.latent_chunk_duration + ) + if num_input_frames % model.tokenizer.video_vae.latent_chunk_duration == 1: + num_latent_frames += 1 + elif num_input_frames % model.tokenizer.video_vae.latent_chunk_duration > 1: + assert ( + num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1 + ) % downsample_factor == 0, f"num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1 must be divisible by {downsample_factor}" + num_latent_frames += ( + 1 + (num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1) // downsample_factor + ) + + return num_latent_frames + + +def create_condition_latent_from_input_frames( + model: DiffusionV2WModel, + input_frames: torch.Tensor, + num_frames_condition: int = 25, +): + """Create condition latent for video generation from input frames. + + Takes the last num_frames_condition frames from input as conditioning. + + Args: + model (DiffusionV2WModel): Video generation model + input_frames (torch.Tensor): Input video tensor [B,C,T,H,W], range [-1,1] + num_frames_condition (int): Number of frames to use for conditioning + + Returns: + tuple: (condition_latent, encode_input_frames) where: + - condition_latent (torch.Tensor): Encoded latent condition [B,C,T,H,W] + - encode_input_frames (torch.Tensor): Padded input frames used for encoding + """ + B, C, T, H, W = input_frames.shape + num_frames_encode = ( + model.tokenizer.pixel_chunk_duration + ) # (model.state_shape[1] - 1) / model.vae.pixel_chunk_duration + 1 + log.debug( + f"num_frames_encode not set, set it based on pixel chunk duration and model state shape: {num_frames_encode}" + ) + + log.debug( + f"Create condition latent from input frames {input_frames.shape}, value {input_frames.min()}, {input_frames.max()}, dtype {input_frames.dtype}" + ) + + assert ( + input_frames.shape[2] >= num_frames_condition + ), f"input_frames not enough for condition, require at least {num_frames_condition}, get {input_frames.shape[2]}, {input_frames.shape}" + assert ( + num_frames_encode >= num_frames_condition + ), f"num_frames_encode should be larger than num_frames_condition, get {num_frames_encode}, {num_frames_condition}" + + # Put the conditioal frames to the begining of the video, and pad the end with zero + condition_frames = input_frames[:, :, -num_frames_condition:] + padding_frames = condition_frames.new_zeros(B, C, num_frames_encode - num_frames_condition, H, W) + encode_input_frames = torch.cat([condition_frames, padding_frames], dim=2) + + log.debug( + f"create latent with input shape {encode_input_frames.shape} including padding {num_frames_encode - num_frames_condition} at the end" + ) + latent = model.encode(encode_input_frames) + return latent, encode_input_frames + + +def get_condition_latent( + model: DiffusionV2WModel, + input_image_or_video_path: str, + num_input_frames: int = 1, + state_shape: list[int] = None, +): + """Get condition latent from input image/video file. 
+
+    Args:
+        model (DiffusionV2WModel): Video generation model
+        input_image_or_video_path (str): Path to conditioning image/video
+        num_input_frames (int): Number of input frames for video2world prediction
+        state_shape (list[int], optional): Latent state shape [C,T,H,W]; defaults to model.state_shape
+
+    Returns:
+        condition_latent (torch.Tensor): Encoded latent condition [B,C,T,H,W]
+    """
+    if state_shape is None:
+        state_shape = model.state_shape
+    assert num_input_frames > 0, "num_input_frames must be greater than 0"
+
+    H, W = (
+        state_shape[-2] * model.tokenizer.spatial_compression_factor,
+        state_shape[-1] * model.tokenizer.spatial_compression_factor,
+    )
+
+    input_path_format = input_image_or_video_path.split(".")[-1]
+    input_frames = read_video_or_image_into_frames_BCTHW(
+        input_image_or_video_path,
+        input_path_format=input_path_format,
+        H=H,
+        W=W,
+    )
+
+    condition_latent, _ = create_condition_latent_from_input_frames(model, input_frames, num_input_frames)
+    condition_latent = condition_latent.to(torch.bfloat16)
+
+    return condition_latent
+
+
+def check_input_frames(input_path: str, required_frames: int) -> bool:
+    """Check if input video/image has sufficient frames.
+
+    Args:
+        input_path: Path to input video or image
+        required_frames: Number of required frames
+
+    Returns:
+        bool: True if the input has enough frames for generation, False otherwise
+    """
+    if input_path.endswith((".jpg", ".jpeg", ".png")):
+        if required_frames > 1:
+            log.error(f"Input ({input_path}) is an image but {required_frames} frames are required")
+            return False
+        return True  # Let the pipeline handle image loading
+    # For video input
+    try:
+        vid = imageio.get_reader(input_path, "ffmpeg")
+        frame_count = vid.count_frames()
+
+        if frame_count < required_frames:
+            log.error(f"Input video has {frame_count} frames but {required_frames} frames are required")
+            return False
+        else:
+            return True
+    except Exception as e:
+        log.error(f"Error reading video file {input_path}: {e}")
+        return False
diff --git a/instantiate.py b/instantiate.py
new file mode 100644
index 0000000000000000000000000000000000000000..80112fd66106c65ee2e4cfdc375a53a131b7b57c
--- /dev/null
+++ b/instantiate.py
@@ -0,0 +1,113 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections.abc as abc
+import dataclasses
+import logging
+from typing import Any
+
+import attrs
+
+from .registry import _convert_target_to_string, locate
+
+__all__ = ["dump_dataclass", "instantiate"]
+
+
+def is_dataclass_or_attrs(target):
+    return dataclasses.is_dataclass(target) or attrs.has(target)
+
+
+def dump_dataclass(obj: Any):
+    """
+    Dump a dataclass recursively into a dict that can be later instantiated.
+
+    Args:
+        obj: a dataclass object
+
+    Returns:
+        dict
+    """
+    assert dataclasses.is_dataclass(obj) and not isinstance(
+        obj, type
+    ), "dump_dataclass() requires an instance of a dataclass."
+ ret = {"_target_": _convert_target_to_string(type(obj))} + for f in dataclasses.fields(obj): + v = getattr(obj, f.name) + if dataclasses.is_dataclass(v): + v = dump_dataclass(v) + if isinstance(v, (list, tuple)): + v = [dump_dataclass(x) if dataclasses.is_dataclass(x) else x for x in v] + ret[f.name] = v + return ret + + +def instantiate(cfg, *args, **kwargs): + """ + Recursively instantiate objects defined in dictionaries by + "_target_" and arguments. + + Args: + cfg: a dict-like object with "_target_" that defines the caller, and + other keys that define the arguments + args: Optional positional parameters pass-through. + kwargs: Optional named parameters pass-through. + + Returns: + object instantiated by cfg + """ + from omegaconf import DictConfig, ListConfig, OmegaConf + + if isinstance(cfg, ListConfig): + lst = [instantiate(x) for x in cfg] + return ListConfig(lst, flags={"allow_objects": True}) + if isinstance(cfg, list): + # Specialize for list, because many classes take + # list[objects] as arguments, such as ResNet, DatasetMapper + return [instantiate(x) for x in cfg] + + # If input is a DictConfig backed by dataclasses (i.e. omegaconf's structured config), + # instantiate it to the actual dataclass. + if isinstance(cfg, DictConfig) and is_dataclass_or_attrs(cfg._metadata.object_type): + return OmegaConf.to_object(cfg) + + if isinstance(cfg, abc.Mapping) and "_target_" in cfg: + # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all, + # but faster: https://github.com/facebookresearch/hydra/issues/1200 + cfg = {k: instantiate(v) for k, v in cfg.items()} + cls = cfg.pop("_target_") + cls = instantiate(cls) + + if isinstance(cls, str): + cls_name = cls + cls = locate(cls_name) + assert cls is not None, cls_name + else: + try: + cls_name = cls.__module__ + "." + cls.__qualname__ + except Exception: + # target could be anything, so the above could fail + cls_name = str(cls) + assert callable(cls), f"_target_ {cls} does not define a callable object" + try: + # override config with kwargs + instantiate_kwargs = {} + instantiate_kwargs.update(cfg) + instantiate_kwargs.update(kwargs) + return cls(*args, **instantiate_kwargs) + except TypeError: + logger = logging.getLogger(__name__) + logger.error(f"Error when instantiating {cls_name}!") + raise + return cfg # return as-is if don't know what to do diff --git a/lazy.py b/lazy.py new file mode 100644 index 0000000000000000000000000000000000000000..83585a8477999c5acc146ab790b0cac5959c2e3c --- /dev/null +++ b/lazy.py @@ -0,0 +1,276 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
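instantiate() above resolves the "_target_" key of a mapping to a callable, via locate() when the target is a dotted string, and calls it with the remaining keys as keyword arguments after recursing into nested values. A minimal sketch of that contract (the flat import path is an assumption; adjust it to wherever instantiate.py lands in the package):

import torch.nn as nn

from instantiate import instantiate  # hypothetical import path

layer_cfg = {
    "_target_": "torch.nn.Conv2d",  # dotted path resolved by locate()
    "in_channels": 3,
    "out_channels": 16,
    "kernel_size": 3,
}
layer = instantiate(layer_cfg)  # equivalent to nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
assert isinstance(layer, nn.Conv2d)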
+ +import ast +import builtins +import collections.abc as abc +import importlib +import inspect +import os +import uuid +from collections import OrderedDict +from contextlib import contextmanager +from dataclasses import is_dataclass +from typing import Any, Dict, List, Tuple, Union + +import attrs +import yaml +from omegaconf import DictConfig, ListConfig, OmegaConf + +from .file_io import PathManager +from .registry import _convert_target_to_string + +__all__ = ["LazyCall", "LazyConfig"] + + +def sort_dict(d: Dict[str, Any]) -> OrderedDict[str, Any]: + return OrderedDict(sorted(d.items(), key=lambda x: x[0])) + + +def dict_representer(dumper: yaml.Dumper, data: OrderedDict[str, Any]) -> yaml.nodes.MappingNode: + return dumper.represent_mapping("tag:yaml.org,2002:map", data.items()) + + +def sort_recursive(obj: Union[Dict[str, Any], List[Any], Any]) -> Union[OrderedDict[str, Any], List[Any], Any]: + if isinstance(obj, dict): + return sort_dict({k: sort_recursive(v) for k, v in obj.items()}) + elif isinstance(obj, list): + return [sort_recursive(item) for item in obj] + return obj + + +yaml.add_representer(OrderedDict, dict_representer) + + +def get_default_params(cls_or_func): + if callable(cls_or_func): + # inspect signature for function + signature = inspect.signature(cls_or_func) + else: + # inspect signature for class + signature = inspect.signature(cls_or_func.__init__) + params = signature.parameters + default_params = { + name: param.default for name, param in params.items() if param.default is not inspect.Parameter.empty + } + return default_params + + +class LazyCall: + """ + Wrap a callable so that when it's called, the call will not be executed, + but returns a dict that describes the call. + + LazyCall object has to be called with only keyword arguments. Positional + arguments are not yet supported. + + Examples: + :: + # prevent huggingface from checking detectron2: from detectron2.config import instantiate, LazyCall + + layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32) + layer_cfg.out_channels = 64 # can edit it afterwards + layer = instantiate(layer_cfg) + """ + + def __init__(self, target): + if not (callable(target) or isinstance(target, (str, abc.Mapping))): + raise TypeError(f"target of LazyCall must be a callable or defines a callable! Got {target}") + self._target = target + + def __call__(self, **kwargs): + if is_dataclass(self._target) or attrs.has(self._target): + # omegaconf object cannot hold dataclass type + # https://github.com/omry/omegaconf/issues/784 + target = _convert_target_to_string(self._target) + else: + target = self._target + kwargs["_target_"] = target + + _final_params = get_default_params(self._target) + _final_params.update(kwargs) + + return DictConfig(content=_final_params, flags={"allow_objects": True}) + + +def _visit_dict_config(cfg, func): + """ + Apply func recursively to all DictConfig in cfg. 
+ """ + if isinstance(cfg, DictConfig): + func(cfg) + for v in cfg.values(): + _visit_dict_config(v, func) + elif isinstance(cfg, ListConfig): + for v in cfg: + _visit_dict_config(v, func) + + +def _validate_py_syntax(filename): + # see also https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py + with PathManager.open(filename, "r") as f: + content = f.read() + try: + ast.parse(content) + except SyntaxError as e: + raise SyntaxError(f"Config file {filename} has syntax error!") from e + + +def _cast_to_config(obj): + # if given a dict, return DictConfig instead + if isinstance(obj, dict): + return DictConfig(obj, flags={"allow_objects": True}) + return obj + + +_CFG_PACKAGE_NAME = "detectron2._cfg_loader" +""" +A namespace to put all imported config into. +""" + + +def _random_package_name(filename): + # generate a random package name when loading config files + return _CFG_PACKAGE_NAME + str(uuid.uuid4())[:4] + "." + os.path.basename(filename) + + +@contextmanager +def _patch_import(): + """ + Enhance relative import statements in config files, so that they: + 1. locate files purely based on relative location, regardless of packages. + e.g. you can import file without having __init__ + 2. do not cache modules globally; modifications of module states has no side effect + 3. support other storage system through PathManager, so config files can be in the cloud + 4. imported dict are turned into omegaconf.DictConfig automatically + """ + old_import = builtins.__import__ + + def find_relative_file(original_file, relative_import_path, level): + # NOTE: "from . import x" is not handled. Because then it's unclear + # if such import should produce `x` as a python module or DictConfig. + # This can be discussed further if needed. + relative_import_err = """ +Relative import of directories is not allowed within config files. +Within a config file, relative import can only import other config files. +""".replace( + "\n", " " + ) + if not len(relative_import_path): + raise ImportError(relative_import_err) + + cur_file = os.path.dirname(original_file) + for _ in range(level - 1): + cur_file = os.path.dirname(cur_file) + cur_name = relative_import_path.lstrip(".") + for part in cur_name.split("."): + cur_file = os.path.join(cur_file, part) + if not cur_file.endswith(".py"): + cur_file += ".py" + if not PathManager.isfile(cur_file): + cur_file_no_suffix = cur_file[: -len(".py")] + if PathManager.isdir(cur_file_no_suffix): + raise ImportError(f"Cannot import from {cur_file_no_suffix}." + relative_import_err) + else: + raise ImportError( + f"Cannot import name {relative_import_path} from " f"{original_file}: {cur_file} does not exist." 
+ ) + return cur_file + + def new_import(name, globals=None, locals=None, fromlist=(), level=0): + if ( + # Only deal with relative imports inside config files + level != 0 + and globals is not None + and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME) + ): + cur_file = find_relative_file(globals["__file__"], name, level) + _validate_py_syntax(cur_file) + spec = importlib.machinery.ModuleSpec(_random_package_name(cur_file), None, origin=cur_file) + module = importlib.util.module_from_spec(spec) + module.__file__ = cur_file + with PathManager.open(cur_file) as f: + content = f.read() + exec(compile(content, cur_file, "exec"), module.__dict__) + for name in fromlist: # turn imported dict into DictConfig automatically + val = _cast_to_config(module.__dict__[name]) + module.__dict__[name] = val + return module + return old_import(name, globals, locals, fromlist=fromlist, level=level) + + builtins.__import__ = new_import + yield new_import + builtins.__import__ = old_import + + +class LazyConfig: + """ + Provide methods to save, load, and overrides an omegaconf config object + which may contain definition of lazily-constructed objects. + """ + + @staticmethod + def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None): + """ + Load a config file. + + Args: + filename: absolute path or relative path w.r.t. the current working directory + keys: keys to load and return. If not given, return all keys + (whose values are config objects) in a dict. + """ + has_keys = keys is not None + filename = filename.replace("/./", "/") # redundant + if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]: + raise ValueError(f"Config file {filename} has to be a python or yaml file.") + if filename.endswith(".py"): + _validate_py_syntax(filename) + + with _patch_import(): + # Record the filename + module_namespace = { + "__file__": filename, + "__package__": _random_package_name(filename), + } + with PathManager.open(filename) as f: + content = f.read() + # Compile first with filename to: + # 1. make filename appears in stacktrace + # 2. make load_rel able to find its parent's (possibly remote) location + exec(compile(content, filename, "exec"), module_namespace) + + ret = module_namespace + else: + with PathManager.open(filename) as f: + obj = yaml.unsafe_load(f) + ret = OmegaConf.create(obj, flags={"allow_objects": True}) + + if has_keys: + if isinstance(keys, str): + return _cast_to_config(ret[keys]) + else: + return tuple(_cast_to_config(ret[a]) for a in keys) + else: + if filename.endswith(".py"): + # when not specified, only load those that are config objects + ret = DictConfig( + { + name: _cast_to_config(value) + for name, value in ret.items() + if isinstance(value, (DictConfig, ListConfig, dict)) and not name.startswith("_") + }, + flags={"allow_objects": True}, + ) + return ret diff --git a/lazy_config_init.py b/lazy_config_init.py new file mode 100644 index 0000000000000000000000000000000000000000..fa48b6ddc9171218d21c6e4e1d20e267e8b51015 --- /dev/null +++ b/lazy_config_init.py @@ -0,0 +1,60 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
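LazyCall above records a call as an editable omegaconf DictConfig instead of executing it, and instantiate() (from instantiate.py earlier in this diff) materializes the object later. A minimal round trip (the flat import paths are assumptions; adapt them to the real package layout):

import torch.nn as nn

from instantiate import instantiate  # hypothetical import path
from lazy import LazyCall as L       # hypothetical import path

conv_cfg = L(nn.Conv2d)(in_channels=32, out_channels=32, kernel_size=3)
conv_cfg.out_channels = 64           # still a DictConfig, so fields remain editable
conv = instantiate(conv_cfg)         # builds nn.Conv2d(32, 64, kernel_size=3)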
+import os + +from omegaconf import DictConfig, OmegaConf + +from .instantiate import instantiate +from .lazy import LazyCall, LazyConfig +from .omegaconf_patch import to_object + +OmegaConf.to_object = to_object + +PLACEHOLDER = None +LazyDict = DictConfig + +__all__ = ["instantiate", "LazyCall", "LazyConfig", "PLACEHOLDER", "LazyDict"] + + +DOC_BUILDING = os.getenv("_DOC_BUILDING", False) # set in docs/conf.py + + +def fixup_module_metadata(module_name, namespace, keys=None): + """ + Fix the __qualname__ of module members to be their exported api name, so + when they are referenced in docs, sphinx can find them. Reference: + https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241 + """ + if not DOC_BUILDING: + return + seen_ids = set() + + def fix_one(qualname, name, obj): + # avoid infinite recursion (relevant when using + # typing.Generic, for example) + if id(obj) in seen_ids: + return + seen_ids.add(id(obj)) + + mod = getattr(obj, "__module__", None) + if mod is not None and (mod.startswith(module_name) or mod.startswith("fvcore.")): + obj.__module__ = module_name + # Modules, unlike everything else in Python, put fully-qualitied + # names into their __name__ attribute. We check for "." to avoid + # rewriting these. + if hasattr(obj, "__name__") and "." not in obj.__name__: + obj.__name__ = name + obj.__qualname__ = qualname + if isinstance(obj, type): + for attr_name, attr_value in obj.__dict__.items(): + fix_one(objname + "." + attr_name, attr_name, attr_value) + + if keys is None: + keys = namespace.keys() + for objname in keys: + if not objname.startswith("_"): + obj = namespace[objname] + fix_one(objname, objname, obj) + + +fixup_module_metadata(__name__, globals(), __all__) +del fixup_module_metadata diff --git a/log.py b/log.py new file mode 100644 index 0000000000000000000000000000000000000000..132302a1ab96d20cf9a3792958ab04b744cc3c6d --- /dev/null +++ b/log.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
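lazy_config_init.py appears to be the single import surface (instantiate, LazyCall, LazyConfig, LazyDict) intended for config authors. A minimal sketch of writing a Python config file and reading it back with LazyConfig.load (the flat module names in the imports are assumptions; adapt them to the real package path):

from pathlib import Path

from instantiate import instantiate       # hypothetical import path
from lazy import LazyConfig               # hypothetical import path

# Author a config as a plain Python file whose top-level names are LazyCall objects.
Path("toy_cfg.py").write_text(
    "from torch import nn\n"
    "from lazy import LazyCall as L\n"
    "model = L(nn.Linear)(in_features=8, out_features=4)\n"
)

cfg = LazyConfig.load("toy_cfg.py")  # DictConfig holding every top-level config object
model = instantiate(cfg.model)       # materializes nn.Linear(in_features=8, out_features=4)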
+ +import atexit +import os +import sys +from typing import Any, Optional + +import torch.distributed as dist +from loguru._logger import Core, Logger + +RANK0_ONLY = True +LEVEL = os.environ.get("LOGURU_LEVEL", "INFO") + +logger = Logger( + core=Core(), + exception=None, + depth=1, + record=False, + lazy=False, + colors=False, + raw=False, + capture=True, + patchers=[], + extra={}, +) + +atexit.register(logger.remove) + + +def _add_relative_path(record: dict[str, Any]) -> None: + start = os.getcwd() + record["extra"]["relative_path"] = os.path.relpath(record["file"].path, start) + + +*options, _, extra = logger._options # type: ignore +logger._options = tuple([*options, [_add_relative_path], extra]) # type: ignore + + +def init_loguru_stdout() -> None: + logger.remove() + machine_format = get_machine_format() + message_format = get_message_format() + logger.add( + sys.stdout, + level=LEVEL, + format="[{time:MM-DD HH:mm:ss}|" f"{machine_format}" f"{message_format}", + filter=_rank0_only_filter, + ) + + +def get_machine_format() -> str: + node_id = os.environ.get("NGC_ARRAY_INDEX", "0") + num_nodes = int(os.environ.get("NGC_ARRAY_SIZE", "1")) + machine_format = "" + rank = 0 + if dist.is_available(): + if not RANK0_ONLY and dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + machine_format = ( + f"[Node{node_id:<3}/{num_nodes:<3}][RANK{rank:<5}/{world_size:<5}]" + "[{process.name:<8}]| " + ) + return machine_format + + +def get_message_format() -> str: + message_format = "{level}|{extra[relative_path]}:{line}:{function}] {message}" + return message_format + + +def _rank0_only_filter(record: Any) -> bool: + is_rank0 = record["extra"].get("rank0_only", True) + if _get_rank() == 0 and is_rank0: + return True + if not is_rank0: + record["message"] = f"[RANK {_get_rank()}]" + record["message"] + return not is_rank0 + + +class log(): + + @staticmethod + def trace(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).trace(message) + + @staticmethod + def debug(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).debug(message) + + @staticmethod + def info(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).info(message) + + @staticmethod + def success(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).success(message) + + @staticmethod + def warning(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).warning(message) + + @staticmethod + def error(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).error(message) + + @staticmethod + def critical(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).critical(message) + + @staticmethod + def exception(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).exception(message) + + +def _get_rank(group: Optional[dist.ProcessGroup] = None) -> int: + """Get the rank (GPU device) of the worker. + + Returns: + rank (int): The rank of the worker. + """ + rank = 0 + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank(group) + return rank + + +# Execute at import time. 
+init_loguru_stdout()
diff --git a/misc.py b/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ae48798eca2bbeb8003eec1ac6148cde9ef9e6c
--- /dev/null
+++ b/misc.py
@@ -0,0 +1,214 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import collections
+import collections.abc
+import functools
+import json
+import random
+import time
+from contextlib import ContextDecorator
+from typing import Any, Callable, TypeVar
+
+from .log import log
+import numpy as np
+import termcolor
+import torch
+
+from .distributed import get_rank
+
+
+class misc():
+
+    @staticmethod
+    def to(
+        data: Any,
+        device: str | torch.device | None = None,
+        dtype: torch.dtype | None = None,
+        memory_format: torch.memory_format = torch.preserve_format,
+    ) -> Any:
+        """Recursively cast data into the specified device, dtype, and/or memory_format.
+
+        The input data can be a tensor, a list of tensors, a dict of tensors.
+        See the documentation for torch.Tensor.to() for details.
+
+        Args:
+            data (Any): Input data.
+            device (str | torch.device): GPU device (default: None).
+            dtype (torch.dtype): data type (default: None).
+            memory_format (torch.memory_format): memory organization format (default: torch.preserve_format).
+
+        Returns:
+            data (Any): Data cast to the specified device, dtype, and/or memory_format.
+        """
+        assert (
+            device is not None or dtype is not None or memory_format is not None
+        ), "at least one of device, dtype, memory_format should be specified"
+        if isinstance(data, torch.Tensor):
+            is_cpu = (isinstance(device, str) and device == "cpu") or (
+                isinstance(device, torch.device) and device.type == "cpu"
+            )
+            data = data.to(
+                device=device,
+                dtype=dtype,
+                memory_format=memory_format,
+                non_blocking=(not is_cpu),
+            )
+            return data
+        elif isinstance(data, collections.abc.Mapping):
+            # Recurse via the qualified name so the call resolves inside the staticmethod.
+            return type(data)({key: misc.to(data[key], device=device, dtype=dtype, memory_format=memory_format) for key in data})
+        elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
+            return type(data)([misc.to(elem, device=device, dtype=dtype, memory_format=memory_format) for elem in data])
+        else:
+            return data
+
+
+    @staticmethod
+    def serialize(data: Any) -> Any:
+        """Serialize data by hierarchically traversing through iterables.
+
+        Args:
+            data (Any): Input data.
+
+        Returns:
+            data (Any): Serialized data.
+        """
+        if isinstance(data, collections.abc.Mapping):
+            return type(data)({key: misc.serialize(data[key]) for key in data})
+        elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
+            return type(data)([misc.serialize(elem) for elem in data])
+        else:
+            try:
+                json.dumps(data)
+            except TypeError:
+                data = str(data)
+            return data
+
+
+    @staticmethod
+    def set_random_seed(seed: int, by_rank: bool = False) -> None:
+        """Set random seed. This includes random, numpy, Pytorch.
+ + Args: + seed (int): Random seed. + by_rank (bool): if true, each GPU will use a different random seed. + """ + if by_rank: + seed += get_rank() + log.info(f"Using random seed {seed}.") + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) # sets seed on the current CPU & all GPUs + + + @staticmethod + def arch_invariant_rand( + shape: List[int] | Tuple[int], dtype: torch.dtype, device: str | torch.device, seed: int | None = None + ): + """Produce a GPU-architecture-invariant randomized Torch tensor. + + Args: + shape (list or tuple of ints): Output tensor shape. + dtype (torch.dtype): Output tensor type. + device (torch.device): Device holding the output. + seed (int): Optional randomization seed. + + Returns: + tensor (torch.tensor): Randomly-generated tensor. + """ + # Create a random number generator, optionally seeded + rng = np.random.RandomState(seed) + + # # Generate random numbers using the generator + random_array = rng.standard_normal(shape).astype(np.float32) # Use standard_normal for normal distribution + + # Convert to torch tensor and return + return torch.from_numpy(random_array).to(dtype=dtype, device=device) + + +T = TypeVar("T", bound=Callable[..., Any]) + + +class timer(ContextDecorator): # noqa: N801 + """Simple timer for timing the execution of code. + + It can be used as either a context manager or a function decorator. The timing result will be logged upon exit. + + Example: + def func_a(): + time.sleep(1) + with timer("func_a"): + func_a() + + @timer("func_b) + def func_b(): + time.sleep(1) + func_b() + """ + + def __init__(self, context: str, debug: bool = False): + self.context = context + self.debug = debug + + def __enter__(self) -> None: + self.tic = time.time() + + def __exit__(self, exc_type, exc_value, traceback) -> None: # noqa: ANN001 + time_spent = time.time() - self.tic + if self.debug: + log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") + else: + log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") + + def __call__(self, func: T) -> T: + @functools.wraps(func) + def wrapper(*args, **kwargs): # noqa: ANN202 + tic = time.time() + result = func(*args, **kwargs) + time_spent = time.time() - tic + if self.debug: + log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") + else: + log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") + return result + + return wrapper # type: ignore + + +class Color: + """A convenience class to colorize strings in the console. + + Example: + import + print("This is {Color.red('important')}.") + """ + + @staticmethod + def red(x: str) -> str: + return termcolor.colored(str(x), color="red") + + @staticmethod + def green(x: str) -> str: + return termcolor.colored(str(x), color="green") + + @staticmethod + def cyan(x: str) -> str: + return termcolor.colored(str(x), color="cyan") + + @staticmethod + def yellow(x: str) -> str: + return termcolor.colored(str(x), color="yellow") diff --git a/mm_projector.py b/mm_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..ee54c961498ff108a92fe621e9322649f7ad891b --- /dev/null +++ b/mm_projector.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multimodal projector to connect vision encoder / tokenizer with the LLM.""" + +from typing import Any, Optional + +import torch +import torch.nn as nn + + +class DownSampleBlock(nn.Module): + """Downsample block.""" + + def __init__(self): + super().__init__() + + def forward(self, x): + """ + Performs the forward pass of the downsample block. + + Args: + x (torch.Tensor): The input tensor from ViT's output of a sequence of embeddings. + Shape: (b, seq_len, c). + + Returns: + torch.Tensor: The output tensor. Shape: (b, seq_len/4, c*4). + """ + vit_embeds = x + # Get h and w as the sqrt of seq length. This assumes that the input is square-shaped. + h = w = int(vit_embeds.shape[1] ** 0.5) + b = vit_embeds.shape[0] + vit_embeds = vit_embeds.reshape(b, h, w, -1) + vit_embeds = self.flat_square(vit_embeds) + vit_embeds = vit_embeds.reshape(b, -1, vit_embeds.shape[-1]) + return vit_embeds + + def flat_square(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs spatial downsampling while increasing the number of channels. + + Args: + x (torch.Tensor): The input tensor reshaped to a 2D grid. + Shape: (b, h, w, c) + + Returns: + torch.Tensor: The output tensor after the spatial downsampling. + Shape: (b, h/2, w/2, c*4) + """ + b, h, w, c = x.size() + # If w or h is odd, pad a column or a row of zeros. + if h % 2 == 1: + x = torch.concat([x, torch.zeros((b, 1, w, c), dtype=x.dtype).to(x.device)], dim=1).contiguous() + b, h, w, c = x.size() + if w % 2 == 1: + x = torch.concat([x, torch.zeros((b, h, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous() + b, h, w, c = x.size() + # 2x spatial downsampling, 4x channel increasing. + x = x.view(b, h, int(w / 2), int(c * 2)) + x = x.permute(0, 2, 1, 3).contiguous() + x = x.view(b, int(h / 2), int(w / 2), int(c * 4)) + x = x.permute(0, 2, 1, 3).contiguous() + return x + + +class MultimodalProjector(nn.Module): + """Multimodal projector.""" + + def __init__( + self, + mm_projector_type: str, + in_dim: int, + out_dim: Optional[int] = None, + **kwargs: Any, + ): + super().__init__() + if out_dim is None: + out_dim = in_dim + if mm_projector_type == "identity": + self.projector = nn.Identity() + elif mm_projector_type == "linear": + self.projector = nn.Linear(in_dim, out_dim) + elif mm_projector_type == "mlp": + self.projector = nn.Sequential(nn.Linear(in_dim, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim)) + elif mm_projector_type == "mlp_downsample": + self.projector = nn.Sequential( + DownSampleBlock(), + nn.LayerNorm(in_dim * 4), + nn.Linear(in_dim * 4, out_dim), + nn.GELU(), + nn.Linear(out_dim, out_dim), + ) + else: + raise ValueError(f"Unknown projector type: {mm_projector_type}") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.projector(x) diff --git a/model_config.py b/model_config.py new file mode 100644 index 0000000000000000000000000000000000000000..8dc83c93a31aca1ee351e57bf1773d69e9081d9e --- /dev/null +++ b/model_config.py @@ -0,0 +1,421 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import Callable, List, Optional + +from .ar_configs_base_model import ModelConfig +from .ar_config_tokenizer import ( + TextTokenizerConfig, + TokenizerConfig, + VideoTokenizerConfig, + create_discrete_video_fsq_tokenizer_state_dict_config, +) +from .ar_tokenizer_image_text_tokenizer import ImageTextTokenizer +from .ar_tokenizer_text_tokenizer import TextTokenizer +from .log import log +from .lazy_config_init import LazyCall as L + +# Common architecture specifications +BASE_CONFIG = {"n_kv_heads": 8, "norm_type": "rmsnorm", "norm_eps": 1e-5, "ffn_hidden_size": 14336} +COSMOS_ARCHITECTURES = { + "4b": { + "n_layers": 16, + "dim": 4096, + "n_heads": 32, + }, + "12b": { + "n_layers": 40, + "dim": 5120, + "n_heads": 32, + "head_dim": 128, + }, +} + +COSMOS_YARN_CONFIG = { + "original_latent_shape": [3, 40, 64], + "apply_yarn": True, + "yarn_beta_fast": 4, + "yarn_beta_slow": 1, + "yarn_scale": 2, +} + +# Llama3 architecture specifications for different model sizes +LLAMA3_ARCHITECTURES = { + "8b": { + "n_layers": 32, + "dim": 4096, + "n_heads": 32, + "ffn_hidden_size": 14336, + }, +} +# Llama3.1 uses YaRN for long context support (context of 128k tokens) +LLAMA_YARN_CONFIG = { + "apply_yarn": True, + "yarn_scale": 8, + "yarn_beta_fast": 4, + "yarn_beta_slow": 1, +} + +# Mistral architecture specifications for different model sizes +MISTRAL_ARCHITECTURES = { + "12b": { + "n_layers": 40, + "dim": 5120, + "n_heads": 32, + "ffn_hidden_size": 14336, + "head_dim": 128, + }, +} + +PIXTRAL_VISION_ARCHITECTURES = { + "12b": {"vision_encoder": "pixtral-12b-vit", "mm_projector": "mlp"}, +} + + +def get_model_arch_specs(model_size: str, model_family: str = "mistral", pretrained: bool = False) -> dict: + """ + Get the model architecture specifications for the given model size, model family and pretrained status. + + Args: + model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", etc. + model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral" + pretrained (bool): Whether to load pretrained weights. + + Returns: + dict: A dictionary containing the model architecture specifications. 
+ """ + arch_specs = copy.deepcopy(BASE_CONFIG) + model_size = model_size.lower() + if model_family.startswith("cosmos"): + arch_specs.update(COSMOS_ARCHITECTURES[model_size]) + elif model_family.startswith("llama"): + arch_specs.update(LLAMA3_ARCHITECTURES[model_size]) + elif model_family in ["mistral", "pixtral"]: + arch_specs.update(MISTRAL_ARCHITECTURES[model_size]) + if model_family == "pixtral": + arch_specs.update(PIXTRAL_VISION_ARCHITECTURES[model_size]) + else: + raise ValueError(f"Model family {model_family} is not supported.") + + if pretrained: + if model_family == "cosmos": + if model_size == "12b": + arch_specs.update(COSMOS_YARN_CONFIG) + log.debug(f"Using YaRN for RoPE extension with config: {COSMOS_YARN_CONFIG}") + else: + pass + elif model_family in ["llama", "llama3"]: + pretrained_specs = { + "rope_theta": 500000, + "max_seq_len": 8192, + "vocab_size": 128256, + } + arch_specs.update(pretrained_specs) + elif model_family == "llama3.1": + pretrained_specs = { + "rope_theta": 500000, + "max_seq_len": 131072, + "original_seq_len": 8192, + "vocab_size": 128256, + **LLAMA_YARN_CONFIG, + } + arch_specs.update(pretrained_specs) + elif model_family == "mistral": + assert model_size == "12b", "We only support Mistral-Nemo-12B model." + pretrained_specs = { + "rope_theta": 1000000, + "max_seq_len": 128000, + "vocab_size": 131072, + } + arch_specs.update(pretrained_specs) + elif model_family == "pixtral": + assert model_size == "12b", "We only support Pixtral 12B model." + pretrained_specs = {"rope_theta": 1000000000, "max_seq_len": 128000, "vocab_size": 131072} + arch_specs.update(pretrained_specs) + else: + raise ValueError(f"Model family {model_family} doesn't have a pretrained config.") + + return arch_specs + + +def create_text_model_config( + model_ckpt_path: str, + tokenizer_path: str, + model_family: str = "mistral", + model_size: str = "12b", + is_instruct_model: bool = True, + max_seq_len: int = None, + max_batch_size: int = 1, + rope_dim: str = "1D", + add_special_tokens: bool = True, + pytorch_rope_version: str = None, +) -> dict: + """Create a text model for training or inference. + Args: + model_ckpt_path (str): Path to the model checkpoint. + tokenizer_path (str): Path to the tokenizer folder. + model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral". + model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", "8b", "72b", etc. + is_instruct_model (bool): Whether the model is an instruct model. + inference (bool): Whether to create the model for inference. + max_seq_len (int): Maximum sequence length. + max_batch_size (int): Maximum batch size. + rope_dim (str): RoPE dimension. Choices: "1D", "3D". + add_special_tokens (bool): Whether to add special tokens. + Returns: + dict: A dictionary containing the model configuration, which can be used to instantiate the model object. 
+ """ + # Model size specific parameters + model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True) + if max_seq_len is not None: + # Override the max_seq_len if provided + model_arch_specs["max_seq_len"] = max_seq_len + if pytorch_rope_version is not None: + model_arch_specs["pytorch_rope_version"] = pytorch_rope_version + model_config = ModelConfig( + max_batch_size=max_batch_size, + precision="bfloat16", + ckpt_path=model_ckpt_path, + use_qk_normalization=False, + rope_dim=rope_dim, + **model_arch_specs, + ) + + tokenizer_config = TokenizerConfig( + text_tokenizer=TextTokenizerConfig( + config=L(TextTokenizer)( + model_family=model_family, + is_instruct_model=is_instruct_model, + local_path=tokenizer_path, + ), + data_key="text", + tokenizer_offset=model_config.vocab_size, + tokenize_here=False, + vocab_size=model_config.vocab_size, + ), + seq_len=model_config.max_seq_len, + training_type="text_only", + add_special_tokens=add_special_tokens, + ) + return model_config, tokenizer_config + + +def create_vision_language_model_config( + model_ckpt_path: str, + tokenizer_ckpt_path: str, + model_family: str = "pixtral", + model_size: str = "12b", + is_instruct_model: bool = True, + max_batch_size: int = 1, + rope_dim: str = "1D", + add_special_tokens: bool = True, + max_seq_len: int = None, + vision_encoder_in_channels: int = 3, + fuse_qkv: bool = False, + pytorch_rope_version: str = None, +) -> dict: + """Create a vision-language model for training or inference. + Args: + model_ckpt_path (str): Path to the model checkpoint. + tokenizer_ckpt_path (str): Path to the tokenizer checkpoint. + model_family (str): Model family. Choices: "pixtral". + model_size (str): Model size. Choices: "12b". + is_instruct_model (bool): Whether the model is an instruct model. + rope_dim (str): RoPE dimension. Choices: "1D". + add_special_tokens (bool): Whether to add special tokens. + max_seq_len (int): Maximum sequence length. + vision_encoder_in_channels (int): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4 channel images where last channel is binary mask, set this to 4. + fuse_qkv (bool): Whether to fuse the QKV linear layers. + Returns: + dict: A dictionary containing the model configuration, which can be used to instantiate the model object. 
+ """ + # Model size specific parameters + model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True) + if max_seq_len is not None: + # Override the max_seq_len if provided + model_arch_specs["max_seq_len"] = max_seq_len + if pytorch_rope_version is not None: + model_arch_specs["pytorch_rope_version"] = pytorch_rope_version + + model_config = ModelConfig( + max_batch_size=max_batch_size, + precision="bfloat16", + ckpt_path=model_ckpt_path, + use_qk_normalization=False, + rope_dim=rope_dim, + vision_encoder_in_channels=vision_encoder_in_channels, + fuse_qkv=fuse_qkv, + **model_arch_specs, + ) + # Vision-language tokenizer + tokenizer_config = TokenizerConfig( + text_tokenizer=TextTokenizerConfig( + config=L(ImageTextTokenizer)( + model_family=model_family, + is_instruct_model=is_instruct_model, + image_processor_path=tokenizer_ckpt_path, + tokenizer_path=tokenizer_ckpt_path, + ), + data_key="image_text_interleaved", + tokenizer_offset=model_config.vocab_size, + tokenize_here=False, + vocab_size=model_config.vocab_size, + ), + seq_len=model_config.max_seq_len, + training_type="image_text_interleaved", + add_special_tokens=add_special_tokens, + ) + return model_config, tokenizer_config + + +def create_video2world_model_config( + model_ckpt_path: str, + tokenizer_ckpt_path: str, + model_family: str = "cosmos", + model_size: str = "4b", + pixel_chunk_duration: int = 9, + num_video_frames: int = 36, + compression_ratio: List[int] = [8, 16, 16], + original_seq_len: int = 8192, + num_condition_latents_t: int = 1, + num_tokens_to_ignore: int = -1, + batch_size: int = 2, + video_tokenizer_config_creator: Callable = create_discrete_video_fsq_tokenizer_state_dict_config, + rope_dim: str = "3D", + add_special_tokens: bool = True, + video_height: int = 384, + video_width: int = 640, + use_qk_normalization: bool = True, + insert_cross_attn: bool = False, + insert_cross_attn_every_k_layers: int = 1, + context_dim: int = 1024, + training_type: str = "video_to_video", + pad_to_multiple_of: Optional[int] = 64, + vocab_size: int = 64000, + apply_abs_pos_emb: bool = False, +) -> dict: + """Create a video-to-world model config. + Args: + model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral". + model_size (str): Model size. Choices: "1b", "8b", "3b". + pixel_chunk_duration (int): Number of frames in each chunk. + num_video_frames (int): Number of video frames. + compression_ratio (List[int]): Compression ratio for the video frames. Choices: [8, 16, 16] or [4, 8, 8]. + original_seq_len (int): Original sequence length. + apply_yarn (bool): Whether to apply YaRN for long context scaling. + yarn_beta_fast (Optional[int]): Fast beta for YaRN. + yarn_beta_slow (Optional[int]): Slow beta for YaRN. + yarn_scale (Optional[int]): Scale factor for ctx extension. + use_qk_normalization (bool): Whether to use Query-Key normalization. + training_type (str): Type of training task. + batch_size (int): Batch size. + video_tokenizer_config_creator (Callable): Method that takes "pixel_chunk_duration: int" and "version: str" as arguments and returns video tokenizer config + video_tokenizer_version (str): Version of the video tokenizer. + num_condition_latents_t (int): Number of conditioning latent channels + num_tokens_to_ignore (int) = Number of tokens to ignore. This takes the precedence + video_height (int): Height of the video frame. Defaults to 384. + video_width (int): Width of the video frame. Defaults to 640. + rope_dim (str): RoPE dimension. 
Choices: "1D", "3D". + add_special_tokens (bool): Whether to add special tokens, use False for 2D/3D RoPE. + pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64. + vocab_size (int): Vocabulary size. + apply_abs_pos_emb (bool): Whether to apply absolute positional embeddings. + Returns: + dict: A dictionary containing the model configuration representing the model object, can be instantiated. + """ + assert ( + pixel_chunk_duration % compression_ratio[0] == 1 + ), f"pixel_chunk_duration({pixel_chunk_duration}) should be k*n + 1 (k={compression_ratio[0]})" + latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1 + latent_height = video_height // compression_ratio[1] + latent_width = video_width // compression_ratio[2] + # Do some math to compute the video latent shape and sequence length + assert ( + num_video_frames % pixel_chunk_duration == 0 + ), f"num_video_frames {num_video_frames} should be divisible by pixel_chunk_duration {pixel_chunk_duration}" + video_latent_shape = [ + num_video_frames // pixel_chunk_duration * latent_chunk_duration, + latent_height, + latent_width, + ] + # product of video_latent_shape + num_token_video_latent = video_latent_shape[0] * video_latent_shape[1] * video_latent_shape[2] + if add_special_tokens: + seq_len = num_token_video_latent + 3 # Sequence length per batch, max_seq_len + 3 + seq_len = (seq_len + 63) // 64 * 64 # Round up to multiple of 64 + # for text to video, we need to add token to indicate the start of the video + elif training_type == "text_to_video": + seq_len = num_token_video_latent + 1 + else: + seq_len = num_token_video_latent + + if seq_len % pad_to_multiple_of != 0: + # Round up to the nearest multiple of pad_to_multiple_of + seq_len = ((seq_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of + + # Model size specific parameters + model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True) + + # Whether skip the loss for first chunk or not, note the first token is already skipped when computing the loss + # If num_tokens_to_ignore is specified, use it. 
+ # Else compute it from num_condition_latents_t + if num_tokens_to_ignore < 0: + num_tokens_to_ignore = latent_height * latent_width * num_condition_latents_t + if not add_special_tokens and num_condition_latents_t > 0: + # If there are no special tokens (bov), do a -1 so that you can compute the loss + # from the first token of the next chunk + num_tokens_to_ignore -= 1 + + model_config = ModelConfig( + video_height=video_height, + video_width=video_width, + max_seq_len=seq_len, + max_batch_size=batch_size, + precision="bfloat16", + ckpt_path=model_ckpt_path, + use_qk_normalization=use_qk_normalization, + vocab_size=64000, + original_seq_len=original_seq_len, + video_latent_shape=video_latent_shape, + num_video_frames=num_video_frames, + rope_dim=rope_dim, + pad_to_multiple_of=pad_to_multiple_of, + insert_cross_attn=insert_cross_attn, + insert_cross_attn_every_k_layers=insert_cross_attn_every_k_layers, + context_dim=context_dim, + apply_abs_pos_emb=apply_abs_pos_emb, + **model_arch_specs, + ) + + video_tokenizer_config = video_tokenizer_config_creator( + tokenizer_ckpt_path, pixel_chunk_duration, compression_ratio + ) + tokenizer_config = TokenizerConfig( + text_tokenizer=None, + video_tokenizer=VideoTokenizerConfig( + config=video_tokenizer_config, + data_key="video", + tokenizer_offset=0, # Since there is no text embeddings in the model. Note this only apply when the model is trained from scratch. If we use text pretrained model, the offset will be vocab_size of text token. + tokenize_here=True, + max_seq_len=num_token_video_latent, + vocab_size=vocab_size, + ), + seq_len=seq_len, + training_type=training_type, + add_special_tokens=add_special_tokens, + pad_to_multiple_of=pad_to_multiple_of, + ) + return model_config, tokenizer_config diff --git a/model_t2w.py b/model_t2w.py new file mode 100644 index 0000000000000000000000000000000000000000..cecf4aab16cc96b15a342483adce231c43c01a73 --- /dev/null +++ b/model_t2w.py @@ -0,0 +1,282 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, Optional, Tuple + +import torch +from torch import Tensor + +from .conditioner import CosmosCondition +from .batch_ops import batch_mul +from .denoiser_scaling import EDMScaling +from .res_sampler import COMMON_SOLVER_OPTIONS, Sampler +from .diffusion_types import DenoisePrediction +from .blocks import FourierFeatures +from .pretrained_vae import BaseVAE +from .misc import misc, Color, timer +from .instantiate import instantiate as lazy_instantiate +from .log import log + + +class EDMSDE: + def __init__( + self, + sigma_max: float, + sigma_min: float, + ): + self.sigma_max = sigma_max + self.sigma_min = sigma_min + + +class DiffusionT2WModel(torch.nn.Module): + """Text-to-world diffusion model that generates video frames from text descriptions. 
+ + This model implements a diffusion-based approach for generating videos conditioned on text input. + It handles the full pipeline including encoding/decoding through a VAE, diffusion sampling, + and classifier-free guidance. + """ + + def __init__(self, config): + """Initialize the diffusion model. + + Args: + config: Configuration object containing model parameters and architecture settings + """ + super().__init__() + # Initialize trained_data_record with defaultdict, key: image, video, iteration + self.config = config + + self.precision = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[config.precision] + self.tensor_kwargs = {"device": "cuda", "dtype": self.precision} + log.debug(f"DiffusionModel: precision {self.precision}") + # Timer passed to network to detect slow ranks. + # 1. set data keys and data information + self.sigma_data = config.sigma_data + self.state_shape = list(config.latent_shape) + self.setup_data_key() + + # 2. setup up diffusion processing and scaling~(pre-condition), sampler + self.sde = EDMSDE(sigma_max=80, sigma_min=0.0002) + self.sampler = Sampler() + self.scaling = EDMScaling(self.sigma_data) + self.tokenizer = None + self.model = None + + @property + def net(self): + return self.model.net + + @property + def conditioner(self): + return self.model.conditioner + + @property + def logvar(self): + return self.model.logvar + + def set_up_tokenizer(self, tokenizer_dir: str): + self.tokenizer: BaseVAE = lazy_instantiate(self.config.tokenizer) + self.tokenizer.load_weights(tokenizer_dir) + if hasattr(self.tokenizer, "reset_dtype"): + self.tokenizer.reset_dtype() + + @timer("DiffusionModel: set_up_model") + def set_up_model(self, memory_format: torch.memory_format = torch.preserve_format): + """Initialize the core model components including network, conditioner and logvar.""" + self.model = self.build_model() + self.model = self.model.to(memory_format=memory_format, **self.tensor_kwargs) + + def build_model(self) -> torch.nn.ModuleDict: + """Construct the model's neural network components. + + Returns: + ModuleDict containing the network, conditioner and logvar components + """ + config = self.config + net = lazy_instantiate(config.net) + conditioner = lazy_instantiate(config.conditioner) + logvar = torch.nn.Sequential( + FourierFeatures(num_channels=128, normalize=True), torch.nn.Linear(128, 1, bias=False) + ) + + return torch.nn.ModuleDict( + { + "net": net, + "conditioner": conditioner, + "logvar": logvar, + } + ) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + """Encode input state into latent representation using VAE. + + Args: + state: Input tensor to encode + + Returns: + Encoded latent representation scaled by sigma_data + """ + return self.tokenizer.encode(state) * self.sigma_data + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """Decode latent representation back to pixel space using VAE. 
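+
+        Example (illustrative round trip; assumes set_up_tokenizer has been called and `frames` is a pixel-space tensor):
+            >>> latent = model.encode(frames)   # pixel space -> latent, scaled by sigma_data
+            >>> recon = model.decode(latent)    # latent -> pixel space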
+ + Args: + latent: Latent tensor to decode + + Returns: + Decoded tensor in pixel space + """ + return self.tokenizer.decode(latent / self.sigma_data) + + def setup_data_key(self) -> None: + """Configure input data keys for video and image data.""" + self.input_data_key = self.config.input_data_key # by default it is video key for Video diffusion model + + def get_x0_fn_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + ) -> Callable: + """ + Generates a callable function `x0_fn` based on the provided data batch and guidance factor. + + This function processes the input data batch through a conditioning workflow to obtain + conditioned and unconditioned states. It then defines a nested function `x0_fn` which + applies denoising on an input `noise_x` at a given noise level `sigma`. + + Args: + data_batch: A batch of data used for conditioning. Format should align with conditioner + guidance: Scalar value that modulates influence of conditioned vs unconditioned state + is_negative_prompt: Use negative prompt t5 in uncondition if true + + Returns: + A function `x0_fn(noise_x, sigma)` that takes noise_x and sigma, returns x0 prediction + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise(noise_x, sigma, condition).x0 + uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 + raw_x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) + if "guided_image" in data_batch: + # replacement trick that enables inpainting with base model + assert "guided_mask" in data_batch, "guided_mask should be in data_batch if guided_image is present" + guide_image = data_batch["guided_image"] + guide_mask = data_batch["guided_mask"] + raw_x0 = guide_mask * guide_image + (1 - guide_mask) * raw_x0 + + return raw_x0 + + return x0_fn + + def denoise(self, xt: torch.Tensor, sigma: torch.Tensor, condition: CosmosCondition) -> DenoisePrediction: + """ + Performs denoising on the input noise data, noise level, and condition + + Args: + xt (torch.Tensor): The input noise data. + sigma (torch.Tensor): The noise level. + condition (CosmosCondition): conditional information, generated from self.conditioner + + Returns: + DenoisePrediction: The denoised prediction, it includes clean data predicton (x0), \ + noise prediction (eps_pred) and optional confidence (logvar). + """ + + xt = xt.to(**self.tensor_kwargs) + sigma = sigma.to(**self.tensor_kwargs) + # get precondition for the network + c_skip, c_out, c_in, c_noise = self.scaling(sigma=sigma) + + # forward pass through the network + net_output = self.net( + x=batch_mul(c_in, xt), # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf + timesteps=c_noise, # Eq. 
7 of https://arxiv.org/pdf/2206.00364.pdf + **condition.to_dict(), + ) + + logvar = self.model.logvar(c_noise) + x0_pred = batch_mul(c_skip, xt) + batch_mul(c_out, net_output) + + # get noise prediction based on sde + eps_pred = batch_mul(xt - x0_pred, 1.0 / sigma) + + return DenoisePrediction(x0_pred, eps_pred, logvar) + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + solver_option: COMMON_SOLVER_OPTIONS = "2ab", + x_sigma_max: Optional[torch.Tensor] = None, + sigma_max: float | None = None, + ) -> Tensor: + """Generate samples from a data batch using diffusion sampling. + + This function generates samples from either image or video data batches using diffusion sampling. + It handles both conditional and unconditional generation with classifier-free guidance. + + Args: + data_batch (Dict): Raw data batch from the training data loader + guidance (float, optional): Classifier-free guidance weight. Defaults to 1.5. + seed (int, optional): Random seed for reproducibility. Defaults to 1. + state_shape (Tuple | None, optional): Shape of the state tensor. Uses self.state_shape if None. Defaults to None. + n_sample (int | None, optional): Number of samples to generate. Defaults to None. + is_negative_prompt (bool, optional): Whether to use negative prompt for unconditional generation. Defaults to False. + num_steps (int, optional): Number of diffusion sampling steps. Defaults to 35. + solver_option (COMMON_SOLVER_OPTIONS, optional): Differential equation solver option. Defaults to "2ab" (multistep solver). + x_sigma_max (Optional[torch.Tensor], optional): Initial noisy tensor. If None, randomly initialized. Defaults to None. + sigma_max (float | None, optional): Maximum noise level. Uses self.sde.sigma_max if None. Defaults to None. + + Returns: + Tensor: Generated samples after diffusion sampling + """ + x0_fn = self.get_x0_fn_from_batch(data_batch, guidance, is_negative_prompt=is_negative_prompt) + if sigma_max is None: + sigma_max = self.sde.sigma_max + else: + log.info("Using provided sigma_max for diffusion sampling.") + if x_sigma_max is None: + x_sigma_max = ( + misc.arch_invariant_rand( + (n_sample,) + tuple(state_shape), + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + * sigma_max + ) + + samples = self.sampler( + x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max, solver_option=solver_option + ) + + return samples diff --git a/model_v2w.py b/model_v2w.py new file mode 100644 index 0000000000000000000000000000000000000000..f5cb117d95fd9509077028c8bf2015957047653d --- /dev/null +++ b/model_v2w.py @@ -0,0 +1,341 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
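+
+# Illustrative end-to-end usage sketch for the text-to-world model imported below; the video-to-world
+# subclass defined in this file follows the same pattern with a condition_latent supplied.
+# Not part of the public API; `config`, `ckpt_dir`, and `data_batch` are assumed placeholders:
+#
+#     model = DiffusionT2WModel(config)
+#     model.set_up_tokenizer(f"{ckpt_dir}/tokenizer")
+#     model.set_up_model()
+#     latents = model.generate_samples_from_batch(
+#         data_batch, guidance=7.0, state_shape=model.state_shape, n_sample=1, num_steps=35
+#     )
+#     video = model.decode(latents)  # latent space -> pixel space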
+ +from dataclasses import dataclass +from typing import Callable, Dict, Optional, Tuple, Union + +from .log import log +import torch +from torch import Tensor + +from .conditioner import VideoExtendCondition +from .config_base_conditioner import VideoCondBoolConfig +from .batch_ops import batch_mul +from .model_t2w import DiffusionT2WModel +from .misc import misc, Color, timer + + +@dataclass +class VideoDenoisePrediction: + x0: torch.Tensor # clean data prediction + eps: Optional[torch.Tensor] = None # noise prediction + logvar: Optional[torch.Tensor] = None # log variance of noise prediction, can be used a confidence / uncertainty + xt: Optional[torch.Tensor] = None # input to the network, before muliply with c_in + x0_pred_replaced: Optional[torch.Tensor] = None # x0 prediction with condition region replaced by gt_latent + + +class DiffusionV2WModel(DiffusionT2WModel): + def __init__(self, config): + super().__init__(config) + + def augment_conditional_latent_frames( + self, + condition: VideoExtendCondition, + cfg_video_cond_bool: VideoCondBoolConfig, + gt_latent: Tensor, + condition_video_augment_sigma_in_inference: float = 0.001, + sigma: Tensor = None, + seed: int = 1, + ) -> Union[VideoExtendCondition, Tensor]: + """Augments the conditional frames with noise during inference. + + Args: + condition (VideoExtendCondition): condition object + condition_video_indicator: binary tensor indicating the region is condition(value=1) or generation(value=0). Bx1xTx1x1 tensor. + condition_video_input_mask: input mask for the network input, indicating the condition region. B,1,T,H,W tensor. will be concat with the input for the network. + cfg_video_cond_bool (VideoCondBoolConfig): video condition bool config + gt_latent (Tensor): ground truth latent tensor in shape B,C,T,H,W + condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference + sigma (Tensor): noise level for the generation region + seed (int): random seed for reproducibility + Returns: + VideoExtendCondition: updated condition object + condition_video_augment_sigma: sigma for the condition region, feed to the network + augment_latent (Tensor): augmented latent tensor in shape B,C,T,H,W + + """ + + # Inference only, use fixed sigma for the condition region + assert ( + condition_video_augment_sigma_in_inference is not None + ), "condition_video_augment_sigma_in_inference should be provided" + augment_sigma = condition_video_augment_sigma_in_inference + + if augment_sigma >= sigma.flatten()[0]: + # This is a inference trick! If the sampling sigma is smaller than the augment sigma, we will start denoising the condition region together. + # This is achieved by setting all region as `generation`, i.e. 
value=0 + log.debug("augment_sigma larger than sigma or other frame, remove condition") + condition.condition_video_indicator = condition.condition_video_indicator * 0 + + augment_sigma = torch.tensor([augment_sigma], **self.tensor_kwargs) + + # Now apply the augment_sigma to the gt_latent + + noise = misc.arch_invariant_rand( + gt_latent.shape, + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + + augment_latent = gt_latent + noise * augment_sigma[:, None, None, None, None] + + _, _, c_in_augment, _ = self.scaling(sigma=augment_sigma) + + # Multiply the whole latent with c_in_augment + augment_latent_cin = batch_mul(augment_latent, c_in_augment) + + # Since the whole latent will multiply with c_in later, we devide the value to cancel the effect + _, _, c_in, _ = self.scaling(sigma=sigma) + augment_latent_cin = batch_mul(augment_latent_cin, 1 / c_in) + + return condition, augment_latent_cin + + def denoise( + self, + noise_x: Tensor, + sigma: Tensor, + condition: VideoExtendCondition, + condition_video_augment_sigma_in_inference: float = 0.001, + seed: int = 1, + ) -> VideoDenoisePrediction: + """Denoises input tensor using conditional video generation. + + Args: + noise_x (Tensor): Noisy input tensor. + sigma (Tensor): Noise level. + condition (VideoExtendCondition): Condition for denoising. + condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference + seed (int): Random seed for reproducibility + Returns: + VideoDenoisePrediction containing: + - x0: Denoised prediction + - eps: Noise prediction + - logvar: Log variance of noise prediction + - xt: Input before c_in multiplication + - x0_pred_replaced: x0 prediction with condition regions replaced by ground truth + """ + + assert ( + condition.gt_latent is not None + ), f"find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition or this is a image batch but condition.data_type is wrong, get {noise_x.shape}" + gt_latent = condition.gt_latent + cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool + + condition_latent = gt_latent + + # Augment the latent with different sigma value, and add the augment_sigma to the condition object if needed + condition, augment_latent = self.augment_conditional_latent_frames( + condition, cfg_video_cond_bool, condition_latent, condition_video_augment_sigma_in_inference, sigma, seed + ) + condition_video_indicator = condition.condition_video_indicator # [B, 1, T, 1, 1] + + # Compose the model input with condition region (augment_latent) and generation region (noise_x) + new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x + # Call the abse model + denoise_pred = super().denoise(new_noise_xt, sigma, condition) + + x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0 + + x0_pred = x0_pred_replaced + + return VideoDenoisePrediction( + x0=x0_pred, + eps=batch_mul(noise_x - x0_pred, 1.0 / sigma), + logvar=denoise_pred.logvar, + xt=new_noise_xt, + x0_pred_replaced=x0_pred_replaced, + ) + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + condition_latent: Union[torch.Tensor, None] = None, + num_condition_t: Union[int, None] = None, + 
condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + x_sigma_max: Optional[torch.Tensor] = None, + ) -> Tensor: + """Generates video samples conditioned on input frames. + + Args: + data_batch: Input data dictionary + guidance: Classifier-free guidance scale + seed: Random seed for reproducibility + state_shape: Shape of output tensor (defaults to model's state shape) + n_sample: Number of samples to generate (defaults to batch size) + is_negative_prompt: Whether to use negative prompting + num_steps: Number of denoising steps + condition_latent: Conditioning frames tensor (B,C,T,H,W) + num_condition_t: Number of frames to condition on + condition_video_augment_sigma_in_inference: Noise level for condition augmentation + add_input_frames_guidance: Whether to apply guidance to input frames + x_sigma_max: Maximum noise level tensor + + Returns: + Generated video samples tensor + """ + + if n_sample is None: + input_key = self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + log.debug(f"Default Video state shape is used. {self.state_shape}") + state_shape = self.state_shape + + assert condition_latent is not None, "condition_latent should be provided" + + x0_fn = self.get_x0_fn_from_batch_with_condition_latent( + data_batch, + guidance, + is_negative_prompt=is_negative_prompt, + condition_latent=condition_latent, + num_condition_t=num_condition_t, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + add_input_frames_guidance=add_input_frames_guidance, + seed=seed, + ) + if x_sigma_max is None: + x_sigma_max = ( + misc.arch_invariant_rand( + (n_sample,) + tuple(state_shape), + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + * self.sde.sigma_max + ) + + samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max) + return samples + + def get_x0_fn_from_batch_with_condition_latent( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + condition_latent: torch.Tensor = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + seed: int = 1, + ) -> Callable: + """Creates denoising function for conditional video generation. 
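+
+        Example (illustrative; `data_batch`, `condition_latent`, `noise_x`, and `sigma` are assumed to be prepared upstream):
+            >>> x0_fn = model.get_x0_fn_from_batch_with_condition_latent(
+            ...     data_batch, guidance=1.5, condition_latent=condition_latent, num_condition_t=1
+            ... )
+            >>> x0 = x0_fn(noise_x, sigma)  # CFG-combined x0 prediction with condition frames replaced by ground truth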
+ + Args: + data_batch: Input data dictionary + guidance: Classifier-free guidance scale + is_negative_prompt: Whether to use negative prompting + condition_latent: Conditioning frames tensor (B,C,T,H,W) + num_condition_t: Number of frames to condition on + condition_video_augment_sigma_in_inference: Noise level for condition augmentation + add_input_frames_guidance: Whether to apply guidance to input frames + seed: Random seed for reproducibility + + Returns: + Function that takes noisy input and noise level and returns denoised prediction + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + condition.video_cond_bool = True + condition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, condition, num_condition_t + ) + + uncondition.video_cond_bool = False if add_input_frames_guidance else True + uncondition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, uncondition, num_condition_t + ) + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise( + noise_x, + sigma, + condition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + seed=seed, + ).x0_pred_replaced + uncond_x0 = self.denoise( + noise_x, + sigma, + uncondition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + seed=seed, + ).x0_pred_replaced + + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return x0_fn + + def add_condition_video_indicator_and_video_input_mask( + self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None + ) -> VideoExtendCondition: + """Adds conditioning masks to VideoExtendCondition object. + + Creates binary indicators and input masks for conditional video generation. 
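+
+        Example (illustrative; `latent_state` is a placeholder (B, C, T, H, W) latent tensor):
+            >>> condition = model.add_condition_video_indicator_and_video_input_mask(
+            ...     latent_state, condition, num_condition_t=1
+            ... )
+            >>> condition.condition_video_indicator.shape  # (1, 1, T, 1, 1); value 1 marks conditioning frames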
+ + Args: + latent_state: Input latent tensor (B,C,T,H,W) + condition: VideoExtendCondition object to update + num_condition_t: Number of frames to condition on + + Returns: + Updated VideoExtendCondition with added masks: + - condition_video_indicator: Binary tensor marking condition regions + - condition_video_input_mask: Input mask for network + - gt_latent: Ground truth latent tensor + """ + T = latent_state.shape[2] + latent_dtype = latent_state.dtype + condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( + latent_dtype + ) # 1 for condition region + + # Only in inference to decide the condition region + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.debug( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + condition_video_indicator[:, :, :num_condition_t] += 1.0 + + condition.gt_latent = latent_state + condition.condition_video_indicator = condition_video_indicator + + B, C, T, H, W = latent_state.shape + # Create additional input_mask channel, this will be concatenated to the input of the network + # See design doc section (Implementation detail A.1 and A.2) for visualization + ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + assert condition.video_cond_bool is not None, "video_cond_bool should be set" + + # The input mask indicate whether the input is conditional region or not + if condition.video_cond_bool: # Condition one given video frames + condition.condition_video_input_mask = ( + condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding + ) + else: # Unconditional case, use for cfg + condition.condition_video_input_mask = zeros_padding + + return condition diff --git a/multi_step.py b/multi_step.py new file mode 100644 index 0000000000000000000000000000000000000000..9e0887e98f9b1b1bb03534ceeeeeb763ac3ba4b2 --- /dev/null +++ b/multi_step.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Impl of multistep methods to solve the ODE in the diffusion model. +""" + +from typing import Callable, List, Tuple + +import torch + +from .runge_kutta import reg_x0_euler_step, res_x0_rk2_step + + +def order2_fn( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_s: torch.Tensor, x0_preds: torch.Tensor +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + impl the second order multistep method in https://arxiv.org/pdf/2308.02157 + Adams Bashforth approach! 
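+
+    Example (illustrative; tensors are placeholders). The first call has no history and falls back to an Euler
+    step; later calls reuse the previous x0 prediction for the second-order Adams-Bashforth update:
+        >>> x_t, x0_history = order2_fn(x_s, s, t, x0_s, x0_preds=[])
+        >>> x_next, x0_history = order2_fn(x_t, t, t_next, x0_t, x0_history)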
+ """ + if x0_preds: + x0_s1, s1 = x0_preds[0] + x_t = res_x0_rk2_step(x_s, t, s, x0_s, s1, x0_s1) + else: + x_t = reg_x0_euler_step(x_s, s, t, x0_s)[0] + return x_t, [(x0_s, s)] + + +# key: method name, value: method function +# key: order + algorithm name +MULTISTEP_FNs = { + "2ab": order2_fn, +} + + +def get_multi_step_fn(name: str) -> Callable: + if name in MULTISTEP_FNs: + return MULTISTEP_FNs[name] + methods = "\n\t".join(MULTISTEP_FNs.keys()) + raise RuntimeError("Only support multistep method\n" + methods) + + +def is_multi_step_fn_supported(name: str) -> bool: + """ + Check if the multistep method is supported. + """ + return name in MULTISTEP_FNs diff --git a/omegaconf_patch.py b/omegaconf_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..39dca42a0a71383de919b750cedf2606faae206d --- /dev/null +++ b/omegaconf_patch.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Union + +from omegaconf import OmegaConf +from omegaconf.base import DictKeyType, SCMode +from omegaconf.dictconfig import DictConfig # pragma: no cover + + +def to_object(cfg: Any) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: + """ + Converts an OmegaConf configuration object to a native Python container (dict or list), unless + the configuration is specifically created by LazyCall, in which case the original configuration + is returned directly. + + This function serves as a modification of the original `to_object` method from OmegaConf, + preventing DictConfig objects created by LazyCall from being automatically converted to Python + dictionaries. This ensures that configurations meant to be lazily evaluated retain their intended + structure and behavior. + + Differences from OmegaConf's original `to_object`: + - Adds a check at the beginning to return the configuration unchanged if it is created by LazyCall. + + Reference: + - Original OmegaConf `to_object` method: https://github.com/omry/omegaconf/blob/master/omegaconf/omegaconf.py#L595 + + Args: + cfg (Any): The OmegaConf configuration object to convert. + + Returns: + Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: The converted Python container if + `cfg` is not a LazyCall created configuration, otherwise the unchanged `cfg`. 
+ + Examples: + >>> cfg = DictConfig({"key": "value", "_target_": "Model"}) + >>> to_object(cfg) + DictConfig({"key": "value", "_target_": "Model"}) + + >>> cfg = DictConfig({"list": [1, 2, 3]}) + >>> to_object(cfg) + {'list': [1, 2, 3]} + """ + if isinstance(cfg, DictConfig) and "_target_" in cfg.keys(): + return cfg + + return OmegaConf.to_container( + cfg=cfg, + resolve=True, + throw_on_missing=True, + enum_to_str=False, + structured_config_mode=SCMode.INSTANTIATE, + ) diff --git a/position_embedding.py b/position_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..4a6c5f27652254282772c5b11f1a2ef61aadfe30 --- /dev/null +++ b/position_embedding.py @@ -0,0 +1,188 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from einops import rearrange, repeat +from torch import nn + +from .attention import normalize +from .timm import trunc_normal_ + + +class VideoPositionEmb(nn.Module): + def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor]) -> torch.Tensor: + """ + It delegates the embedding generation to generate_embeddings function. + """ + B_T_H_W_C = x_B_T_H_W_C.shape + embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps) + + return embeddings + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]): + raise NotImplementedError + + +class VideoRopePosition3DEmb(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + head_dim: int, + len_h: int, + len_w: int, + len_t: int, + base_fps: int = 24, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs + super().__init__() + self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float)) + self.base_fps = base_fps + self.max_h = len_h + self.max_w = len_w + + dim = head_dim + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + self.register_buffer( + "dim_spatial_range", + torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().cuda() / dim_h, + persistent=False, + ) + self.register_buffer( + "dim_temporal_range", + torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().cuda() / dim_t, + persistent=False, + ) + + self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) + self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) + self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) + + def generate_embeddings( + self, + B_T_H_W_C: torch.Size, + fps: Optional[torch.Tensor] = None, + h_ntk_factor: Optional[float] = None, + w_ntk_factor: Optional[float] = None, + t_ntk_factor: Optional[float] = None, + ): + """ + Generate embeddings for the given input size. 
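+
+        Example (illustrative; `rope` is an assumed VideoRopePosition3DEmb instance with head_dim=128 and
+        len_h/len_w/len_t of at least 24/40/8; fps is a placeholder tensor):
+            >>> emb = rope.generate_embeddings(torch.Size([1, 8, 24, 40, 128]), fps=torch.tensor([24.0]))
+            >>> emb.shape  # (T*H*W, 1, 1, head_dim)
+            torch.Size([7680, 1, 1, 128])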
+ + Args: + B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels). + fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. + h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. + w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. + t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. + + Returns: + Not specified in the original code snippet. + """ + h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor + w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor + t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor + + h_theta = 10000.0 * h_ntk_factor + w_theta = 10000.0 * w_ntk_factor + t_theta = 10000.0 * t_ntk_factor + + h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range) + w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range) + temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range) + + B, T, H, W, _ = B_T_H_W_C + uniform_fps = (fps is None) or (fps.min() == fps.max()) + assert ( + uniform_fps or B == 1 or T == 1 + ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1" + assert ( + H <= self.max_h and W <= self.max_w + ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})" + half_emb_h = torch.outer(self.seq[:H], h_spatial_freqs) + half_emb_w = torch.outer(self.seq[:W], w_spatial_freqs) + + # apply sequence scaling in temporal dimension + if fps is None: # image case + assert T == 1, "T should be 1 for image batch." + half_emb_t = torch.outer(self.seq[:T], temporal_freqs) + else: + half_emb_t = torch.outer(self.seq[:T] / fps[:1] * self.base_fps, temporal_freqs) + + em_T_H_W_D = torch.cat( + [ + repeat(half_emb_t, "t d -> t h w d", h=H, w=W), + repeat(half_emb_h, "h d -> t h w d", t=T, w=W), + repeat(half_emb_w, "w d -> t h w d", t=T, h=H), + ] + * 2, + dim=-1, + ) + + return rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float() + + +class LearnablePosEmbAxis(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + interpolation: str, + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + **kwargs, + ): + """ + Args: + interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet. 
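+
+        Example (illustrative; shapes only):
+            >>> pos = LearnablePosEmbAxis(interpolation="crop", model_channels=1024, len_h=48, len_w=80, len_t=16)
+            >>> pos(torch.zeros(1, 8, 24, 40, 1024)).shape  # cropped per-axis embeddings, summed and normalized over channels
+            torch.Size([1, 8, 24, 40, 1024])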
+ """ + del kwargs # unused + super().__init__() + self.interpolation = interpolation + assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" + + self.pos_emb_h = nn.Parameter(torch.zeros(len_h, model_channels)) + self.pos_emb_w = nn.Parameter(torch.zeros(len_w, model_channels)) + self.pos_emb_t = nn.Parameter(torch.zeros(len_t, model_channels)) + + trunc_normal_(self.pos_emb_h, std=0.02) + trunc_normal_(self.pos_emb_w, std=0.02) + trunc_normal_(self.pos_emb_t, std=0.02) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, _ = B_T_H_W_C + if self.interpolation == "crop": + emb_h_H = self.pos_emb_h[:H] + emb_w_W = self.pos_emb_w[:W] + emb_t_T = self.pos_emb_t[:T] + emb = ( + repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W) + + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W) + + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H) + ) + assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" + else: + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + return normalize(emb, dim=-1, eps=1e-6) diff --git a/presets.py b/presets.py new file mode 100644 index 0000000000000000000000000000000000000000..85b9dd38e57e1822239259ae8f44c1bc0ca05e7f --- /dev/null +++ b/presets.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np + +from .aegis import Aegis +from .blocklist import Blocklist +from .guardrail_core import GuardrailRunner +from .face_blur_filter import RetinaFaceFilter +from .video_content_safety_filter import VideoContentSafetyFilter +from .log import log + + +class presets(): + + @staticmethod + def create_text_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner: + """Create the text guardrail runner.""" + blocklist_checkpoint_dir = os.path.join(checkpoint_dir, "blocklist") + aegis_checkpoint_dir = os.path.join(checkpoint_dir, "aegis") + return GuardrailRunner(safety_models=[Blocklist(blocklist_checkpoint_dir), Aegis(aegis_checkpoint_dir)]) + + + @staticmethod + def create_video_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner: + """Create the video guardrail runner.""" + video_filter_checkpoint_dir = os.path.join(checkpoint_dir, "video_content_safety_filter") + retinaface_checkpoint_path = os.path.join(checkpoint_dir, "face_blur_filter/Resnet50_Final.pth") + return GuardrailRunner( + safety_models=[VideoContentSafetyFilter(video_filter_checkpoint_dir)], + postprocessors=[RetinaFaceFilter(retinaface_checkpoint_path)], + ) + + + @staticmethod + def run_text_guardrail(prompt: str, guardrail_runner: GuardrailRunner) -> bool: + """Run the text guardrail on the prompt, checking for content safety. + + Args: + prompt: The text prompt. + guardrail_runner: The text guardrail runner. + + Returns: + bool: Whether the prompt is safe. 
+ """ + is_safe, message = guardrail_runner.run_safety_check(prompt) + if not is_safe: + log.critical(f"GUARDRAIL BLOCKED: {message}") + return is_safe + + + @staticmethod + def run_video_guardrail(frames: np.ndarray, guardrail_runner: GuardrailRunner) -> np.ndarray | None: + """Run the video guardrail on the frames, checking for content safety and applying face blur. + + Args: + frames: The frames of the generated video. + guardrail_runner: The video guardrail runner. + + Returns: + The processed frames if safe, otherwise None. + """ + is_safe, message = guardrail_runner.run_safety_check(frames) + if not is_safe: + log.critical(f"GUARDRAIL BLOCKED: {message}") + return None + + frames = guardrail_runner.postprocess(frames) + return frames diff --git a/pretrained_vae.py b/pretrained_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..05f6f3e548bab062b6f46c0f12377a502bde0dbf --- /dev/null +++ b/pretrained_vae.py @@ -0,0 +1,606 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from abc import ABC, abstractmethod + +import torch +from einops import rearrange +from torch.nn.modules import Module + + +class BaseVAE(torch.nn.Module, ABC): + """ + Abstract base class for a Variational Autoencoder (VAE). + + All subclasses should implement the methods to define the behavior for encoding + and decoding, along with specifying the latent channel size. + """ + + def __init__(self, channel: int = 3, name: str = "vae"): + super().__init__() + self.channel = channel + self.name = name + + @property + def latent_ch(self) -> int: + """ + Returns the number of latent channels in the VAE. + """ + return self.channel + + @abstractmethod + def encode(self, state: torch.Tensor) -> torch.Tensor: + """ + Encodes the input tensor into a latent representation. + + Args: + - state (torch.Tensor): The input tensor to encode. + + Returns: + - torch.Tensor: The encoded latent tensor. + """ + pass + + @abstractmethod + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decodes the latent representation back to the original space. + + Args: + - latent (torch.Tensor): The latent tensor to decode. + + Returns: + - torch.Tensor: The decoded tensor. + """ + pass + + @property + def spatial_compression_factor(self) -> int: + """ + Returns the spatial reduction factor for the VAE. + """ + raise NotImplementedError("The spatial_compression_factor property must be implemented in the derived class.") + + +class BasePretrainedImageVAE(BaseVAE): + """ + A base class for pretrained Variational Autoencoder (VAE) that loads mean and standard deviation values + from a remote store, handles data type conversions, and normalization + using provided mean and standard deviation values for latent space representation. 
+ Derived classes should load pre-trained encoder and decoder components from a remote store + + Attributes: + latent_mean (Tensor): The mean used for normalizing the latent representation. + latent_std (Tensor): The standard deviation used for normalizing the latent representation. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + mean_std_fp (str): File path to the pickle file containing mean and std of the latent space. + latent_ch (int, optional): Number of latent channels (default is 16). + is_image (bool, optional): Flag to indicate whether the output is an image (default is True). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). + """ + + def __init__( + self, + name: str, + latent_ch: int = 16, + is_image: bool = True, + is_bf16: bool = True, + ) -> None: + super().__init__(latent_ch, name) + dtype = torch.bfloat16 if is_bf16 else torch.float32 + self.dtype = dtype + self.is_image = is_image + self.name = name + + def register_mean_std(self, vae_dir: str) -> None: + latent_mean, latent_std = torch.load(os.path.join(vae_dir, "image_mean_std.pt"), weights_only=True) + + target_shape = [1, self.latent_ch, 1, 1] if self.is_image else [1, self.latent_ch, 1, 1, 1] + + self.register_buffer( + "latent_mean", + latent_mean.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + self.register_buffer( + "latent_std", + latent_std.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + """ + Encode the input state to latent space; also handle the dtype conversion, mean and std scaling + """ + in_dtype = state.dtype + latent_mean = self.latent_mean.to(in_dtype) + latent_std = self.latent_std.to(in_dtype) + encoded_state = self.encoder(state.to(self.dtype)) + if isinstance(encoded_state, torch.Tensor): + pass + elif isinstance(encoded_state, tuple): + assert isinstance(encoded_state[0], torch.Tensor) + encoded_state = encoded_state[0] + else: + raise ValueError("Invalid type of encoded state") + return (encoded_state.to(in_dtype) - latent_mean) / latent_std + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decode the input latent to state; also handle the dtype conversion, mean and std scaling + """ + in_dtype = latent.dtype + latent = latent * self.latent_std.to(in_dtype) + self.latent_mean.to(in_dtype) + return self.decoder(latent.to(self.dtype)).to(in_dtype) + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. + """ + del args, kwargs + self.decoder.to(self.dtype) + self.encoder.to(self.dtype) + + +class JITVAE(BasePretrainedImageVAE): + """ + A JIT compiled Variational Autoencoder (VAE) that loads pre-trained encoder + and decoder components from a remote store, handles data type conversions, and normalization + using provided mean and standard deviation values for latent space representation. + + Attributes: + encoder (Module): The JIT compiled encoder loaded from storage. + decoder (Module): The JIT compiled decoder loaded from storage. + latent_mean (Tensor): The mean used for normalizing the latent representation. + latent_std (Tensor): The standard deviation used for normalizing the latent representation. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. 
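+
+    Example (illustrative; `vae_dir` is a placeholder directory containing encoder.jit, decoder.jit, and image_mean_std.pt):
+        >>> vae = JITVAE(name="image_vae")
+        >>> vae.load_encoder(vae_dir)
+        >>> vae.load_decoder(vae_dir)
+        >>> vae.register_mean_std(vae_dir)
+        >>> latent = vae.encode(images)  # images: (B, 3, H, W) tensor in the encoder's expected range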
+ + Args: + name (str): Name of the model, used for differentiating cache file paths. + latent_ch (int, optional): Number of latent channels (default is 16). + is_image (bool, optional): Flag to indicate whether the output is an image (default is True). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). + """ + + def __init__( + self, + name: str, + latent_ch: int = 16, + is_image: bool = True, + is_bf16: bool = True, + ): + super().__init__(name, latent_ch, is_image, is_bf16) + + def load_encoder(self, vae_dir: str) -> None: + """ + Load the encoder from the remote store. + """ + self.encoder = torch.load(os.path.join(vae_dir, "encoder.jit"), weights_only=True) + + self.encoder.eval() + for param in self.encoder.parameters(): + param.requires_grad = False + self.encoder.to(self.dtype) + + def load_decoder(self, vae_dir: str) -> None: + """ + Load the decoder from the remote store. + """ + self.decoder = torch.load(os.path.join(vae_dir, "decoder.jit"), weights_only=True) + + self.decoder.eval() + for param in self.decoder.parameters(): + param.requires_grad = False + self.decoder.to(self.dtype) + + +class BaseVAE(torch.nn.Module, ABC): + """ + Abstract base class for a Variational Autoencoder (VAE). + + All subclasses should implement the methods to define the behavior for encoding + and decoding, along with specifying the latent channel size. + """ + + def __init__(self, channel: int = 3, name: str = "vae"): + super().__init__() + self.channel = channel + self.name = name + + @property + def latent_ch(self) -> int: + """ + Returns the number of latent channels in the VAE. + """ + return self.channel + + @abstractmethod + def encode(self, state: torch.Tensor) -> torch.Tensor: + """ + Encodes the input tensor into a latent representation. + + Args: + - state (torch.Tensor): The input tensor to encode. + + Returns: + - torch.Tensor: The encoded latent tensor. + """ + pass + + @abstractmethod + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decodes the latent representation back to the original space. + + Args: + - latent (torch.Tensor): The latent tensor to decode. + + Returns: + - torch.Tensor: The decoded tensor. + """ + pass + + @property + def spatial_compression_factor(self) -> int: + """ + Returns the spatial reduction factor for the VAE. + """ + raise NotImplementedError("The spatial_compression_factor property must be implemented in the derived class.") + + +class VideoTokenizerInterface(ABC): + @abstractmethod + def encode(self, state: torch.Tensor) -> torch.Tensor: + pass + + @abstractmethod + def decode(self, latent: torch.Tensor) -> torch.Tensor: + pass + + @abstractmethod + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + pass + + @abstractmethod + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + pass + + @property + @abstractmethod + def spatial_compression_factor(self): + pass + + @property + @abstractmethod + def temporal_compression_factor(self): + pass + + @property + @abstractmethod + def spatial_resolution(self): + pass + + @property + @abstractmethod + def pixel_chunk_duration(self): + pass + + @property + @abstractmethod + def latent_chunk_duration(self): + pass + + +class BasePretrainedVideoTokenizer(ABC): + """ + Base class for a pretrained video tokenizer that handles chunking of video data for efficient processing. + + Args: + pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level. 
+ temporal_compress_factor (int): The factor by which the video data is temporally compressed during processing. + max_enc_batch_size (int): The maximum batch size to process in one go during encoding to avoid memory overflow. + max_dec_batch_size (int): The maximum batch size to process in one go during decoding to avoid memory overflow. + + The class introduces parameters for managing temporal chunks (`pixel_chunk_duration` and `temporal_compress_factor`) + which define how video data is subdivided and compressed during the encoding and decoding processes. The + `max_enc_batch_size` and `max_dec_batch_size` parameters allow processing in smaller batches to handle memory + constraints. + """ + + def __init__( + self, + pixel_chunk_duration: int = 17, + temporal_compress_factor: int = 8, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + ): + self._pixel_chunk_duration = pixel_chunk_duration + self._temporal_compress_factor = temporal_compress_factor + self.max_enc_batch_size = max_enc_batch_size + self.max_dec_batch_size = max_dec_batch_size + + def register_mean_std(self, vae_dir: str) -> None: + latent_mean, latent_std = torch.load(os.path.join(vae_dir, "mean_std.pt"), weights_only=True) + + latent_mean = latent_mean.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] + latent_std = latent_std.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] + + target_shape = [1, self.latent_ch, self.latent_chunk_duration, 1, 1] + + self.register_buffer( + "latent_mean", + latent_mean.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + self.register_buffer( + "latent_std", + latent_std.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + + def transform_encode_state_shape(self, state: torch.Tensor) -> torch.Tensor: + """ + Rearranges the input state tensor to the required shape for encoding video data. Mainly for chunk based encoding + """ + B, C, T, H, W = state.shape + assert ( + T % self.pixel_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {self.pixel_chunk_duration}" + return rearrange(state, "b c (n t) h w -> (b n) c t h w", t=self.pixel_chunk_duration) + + def transform_decode_state_shape(self, latent: torch.Tensor) -> torch.Tensor: + B, _, T, _, _ = latent.shape + assert ( + T % self.latent_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {self.latent_chunk_duration}" + return rearrange(latent, "b c (n t) h w -> (b n) c t h w", t=self.latent_chunk_duration) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + if self._temporal_compress_factor == 1: + _, _, origin_T, _, _ = state.shape + state = rearrange(state, "b c t h w -> (b t) c 1 h w") + B, C, T, H, W = state.shape + state = self.transform_encode_state_shape(state) + # use max_enc_batch_size to avoid OOM + if state.shape[0] > self.max_enc_batch_size: + latent = [] + for i in range(0, state.shape[0], self.max_enc_batch_size): + latent.append(super().encode(state[i : i + self.max_enc_batch_size])) + latent = torch.cat(latent, dim=0) + else: + latent = super().encode(state) + + latent = rearrange(latent, "(b n) c t h w -> b c (n t) h w", b=B) + if self._temporal_compress_factor == 1: + latent = rearrange(latent, "(b t) c 1 h w -> b c t h w", t=origin_T) + return latent + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decodes a batch of latent representations into video frames by applying temporal chunking. 
Similar to encode, + it handles video data by processing smaller temporal chunks to reconstruct the original video dimensions. + + It can also decode single frame image data. + + Args: + latent (torch.Tensor): The latent space tensor containing encoded video data. + + Returns: + torch.Tensor: The decoded video tensor reconstructed from latent space. + """ + if self._temporal_compress_factor == 1: + _, _, origin_T, _, _ = latent.shape + latent = rearrange(latent, "b c t h w -> (b t) c 1 h w") + B, _, T, _, _ = latent.shape + latent = self.transform_decode_state_shape(latent) + # use max_enc_batch_size to avoid OOM + if latent.shape[0] > self.max_dec_batch_size: + state = [] + for i in range(0, latent.shape[0], self.max_dec_batch_size): + state.append(super().decode(latent[i : i + self.max_dec_batch_size])) + state = torch.cat(state, dim=0) + else: + state = super().decode(latent) + assert state.shape[2] == self.pixel_chunk_duration + state = rearrange(state, "(b n) c t h w -> b c (n t) h w", b=B) + if self._temporal_compress_factor == 1: + return rearrange(state, "(b t) c 1 h w -> b c t h w", t=origin_T) + return state + + @property + def pixel_chunk_duration(self) -> int: + return self._pixel_chunk_duration + + @property + def latent_chunk_duration(self) -> int: + # return self._latent_chunk_duration + assert (self.pixel_chunk_duration - 1) % self.temporal_compression_factor == 0, ( + f"Pixel chunk duration {self.pixel_chunk_duration} is not divisible by latent chunk duration " + f"{self.latent_chunk_duration}" + ) + return (self.pixel_chunk_duration - 1) // self.temporal_compression_factor + 1 + + @property + def temporal_compression_factor(self): + return self._temporal_compress_factor + + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + if num_pixel_frames == 1: + return 1 + assert ( + num_pixel_frames % self.pixel_chunk_duration == 0 + ), f"Temporal dimension {num_pixel_frames} is not divisible by chunk_length {self.pixel_chunk_duration}" + return num_pixel_frames // self.pixel_chunk_duration * self.latent_chunk_duration + + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + if num_latent_frames == 1: + return 1 + assert ( + num_latent_frames % self.latent_chunk_duration == 0 + ), f"Temporal dimension {num_latent_frames} is not divisible by chunk_length {self.latent_chunk_duration}" + return num_latent_frames // self.latent_chunk_duration * self.pixel_chunk_duration + + +class VideoJITTokenizer(BasePretrainedVideoTokenizer, JITVAE, VideoTokenizerInterface): + """ + Instance of BasePretrainedVideoVAE that loads encoder and decoder from JIT scripted module file + """ + + def __init__( + self, + name: str, + latent_ch: int = 16, + is_bf16: bool = True, + spatial_compression_factor: int = 16, + temporal_compression_factor: int = 8, + pixel_chunk_duration: int = 17, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + spatial_resolution: str = "720", + ): + super().__init__( + pixel_chunk_duration, + temporal_compression_factor, + max_enc_batch_size, + max_dec_batch_size, + ) + super(BasePretrainedVideoTokenizer, self).__init__( + name, + latent_ch, + False, + is_bf16, + ) + + self._spatial_compression_factor = spatial_compression_factor + self._spatial_resolution = spatial_resolution + + @property + def spatial_compression_factor(self): + return self._spatial_compression_factor + + @property + def spatial_resolution(self) -> str: + return self._spatial_resolution + + +class JointImageVideoTokenizer(BaseVAE, VideoTokenizerInterface): + def __init__( 
+ self, + image_vae: torch.nn.Module, + video_vae: torch.nn.Module, + name: str, + latent_ch: int = 16, + squeeze_for_image: bool = True, + ): + super().__init__(latent_ch, name) + self.image_vae = image_vae + self.video_vae = video_vae + self.squeeze_for_image = squeeze_for_image + + def encode_image(self, state: torch.Tensor) -> torch.Tensor: + if self.squeeze_for_image: + return self.image_vae.encode(state.squeeze(2)).unsqueeze(2) + return self.image_vae.encode(state) + + def decode_image(self, latent: torch.Tensor) -> torch.Tensor: + if self.squeeze_for_image: + return self.image_vae.decode(latent.squeeze(2)).unsqueeze(2) + return self.image_vae.decode(latent) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = state.shape + if T == 1: + return self.encode_image(state) + + return self.video_vae.encode(state) + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = latent.shape + if T == 1: + return self.decode_image(latent) + return self.video_vae.decode(latent) + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. + """ + del args, kwargs + self.video_vae.reset_dtype() + + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + if num_pixel_frames == 1: + return 1 + return self.video_vae.get_latent_num_frames(num_pixel_frames) + + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + if num_latent_frames == 1: + return 1 + return self.video_vae.get_pixel_num_frames(num_latent_frames) + + @property + def spatial_compression_factor(self): + return self.video_vae.spatial_compression_factor + + @property + def temporal_compression_factor(self): + return self.video_vae.temporal_compression_factor + + @property + def spatial_resolution(self) -> str: + return self.video_vae.spatial_resolution + + @property + def pixel_chunk_duration(self) -> int: + return self.video_vae.pixel_chunk_duration + + @property + def latent_chunk_duration(self) -> int: + return self.video_vae.latent_chunk_duration + + +class JointImageVideoSharedJITTokenizer(JointImageVideoTokenizer): + """ + First version of the ImageVideoVAE trained with Fitsum. + We have to use seperate mean and std for image and video due to non-causal nature of the model. + """ + + def __init__(self, image_vae: Module, video_vae: Module, name: str, latent_ch: int = 16): + super().__init__(image_vae, video_vae, name, latent_ch, squeeze_for_image=False) + assert isinstance(image_vae, JITVAE) + assert isinstance( + video_vae, VideoJITTokenizer + ), f"video_vae should be an instance of VideoJITVAE, got {type(video_vae)}" + # a hack to make the image_vae and video_vae share the same encoder and decoder + + def load_weights(self, vae_dir: str): + self.video_vae.register_mean_std(vae_dir) + + self.video_vae.load_decoder(vae_dir) + self.video_vae.load_encoder(vae_dir) diff --git a/registry.py b/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..7c09eb428a97927d5f0407e2328a3f43afbf38fc --- /dev/null +++ b/registry.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pydoc +from typing import Any + +""" +`locate` provide ways to map a string (typically found +in config files) to callable objects. +""" + +__all__ = ["locate"] + + +def _convert_target_to_string(t: Any) -> str: + """ + Inverse of ``locate()``. + + Args: + t: any object with ``__module__`` and ``__qualname__`` + """ + module, qualname = t.__module__, t.__qualname__ + + # Compress the path to this object, e.g. ``module.submodule._impl.class`` + # may become ``module.submodule.class``, if the later also resolves to the same + # object. This simplifies the string, and also is less affected by moving the + # class implementation. + module_parts = module.split(".") + for k in range(1, len(module_parts)): + prefix = ".".join(module_parts[:k]) + candidate = f"{prefix}.{qualname}" + try: + if locate(candidate) is t: + return candidate + except ImportError: + pass + return f"{module}.{qualname}" + + +def locate(name: str) -> Any: + """ + Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``, + such as "module.submodule.class_name". + + Raise Exception if it cannot be found. + """ + obj = pydoc.locate(name) + + # Some cases (e.g. torch.optim.sgd.SGD) not handled correctly + # by pydoc.locate. Try a private function from hydra. + if obj is None: + try: + # from hydra.utils import get_method - will print many errors + from hydra.utils import _locate + except ImportError as e: + raise ImportError(f"Cannot dynamically locate object {name}!") from e + else: + obj = _locate(name) # it raises if fails + + return obj diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a629f6c927f66937cbc7b7058142a0a494387ec8 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Please keep requirements sorted alphabetically +av +better-profanity +git+https://github.com/NVlabs/Pytorch_Retinaface.git@b843f45 +hydra-core +imageio[ffmpeg] +iopath +loguru +mediapy +nltk +peft +pillow +sentencepiece +termcolor +transformers==4.45.0 diff --git a/res_sampler.py b/res_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..c531bcd5f8304985c68ce09350ef48297217cad7 --- /dev/null +++ b/res_sampler.py @@ -0,0 +1,283 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A general framework for various sampling algorithm from a diffusion model. +Impl based on +* Refined Exponential Solver (RES) in https://arxiv.org/pdf/2308.02157 +* also clude other impl, DDIM, DEIS, DPM-Solver, EDM sampler. +Most of sampling algorihtm, Runge-Kutta, Multi-step, etc, can be impl in this framework by \ + adding new step function in get_runge_kutta_fn or get_multi_step_fn. +""" + +import math +from typing import Any, Callable, List, Literal, Optional, Tuple, Union + +import attrs +import torch + +from .multi_step import get_multi_step_fn, is_multi_step_fn_supported +from .runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported +from .config import make_freezable + +COMMON_SOLVER_OPTIONS = Literal["2ab", "2mid", "1euler"] + + +@make_freezable +@attrs.define(slots=False) +class SolverConfig: + is_multi: bool = False + rk: str = "2mid" + multistep: str = "2ab" + # following parameters control stochasticity, see EDM paper + # BY default, we use deterministic with no stochasticity + s_churn: float = 0.0 + s_t_max: float = float("inf") + s_t_min: float = 0.05 + s_noise: float = 1.0 + + +@make_freezable +@attrs.define(slots=False) +class SolverTimestampConfig: + nfe: int = 50 + t_min: float = 0.002 + t_max: float = 80.0 + order: float = 7.0 + is_forward: bool = False # whether generate forward or backward timestamps + + +@make_freezable +@attrs.define(slots=False) +class SamplerConfig: + solver: SolverConfig = attrs.field(factory=SolverConfig) + timestamps: SolverTimestampConfig = attrs.field(factory=SolverTimestampConfig) + sample_clean: bool = True # whether run one last step to generate clean image + + +def get_rev_ts( + t_min: float, t_max: float, num_steps: int, ts_order: Union[int, float], is_forward: bool = False +) -> torch.Tensor: + """ + Generate a sequence of reverse time steps. + + Args: + t_min (float): The minimum time value. + t_max (float): The maximum time value. + num_steps (int): The number of time steps to generate. + ts_order (Union[int, float]): The order of the time step progression. + is_forward (bool, optional): If True, returns the sequence in forward order. Defaults to False. + + Returns: + torch.Tensor: A tensor containing the generated time steps in reverse or forward order. + + Raises: + ValueError: If `t_min` is not less than `t_max`. + TypeError: If `ts_order` is not an integer or float. 
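+
+    Example:
+        Illustrative sketch only (values are approximate): `get_rev_ts(0.002, 80.0, 2, 7)`
+        returns a float64 tensor that decreases from `t_max` to `t_min` following the
+        EDM-style rho spacing, roughly [80.0, 2.5, 0.002]. Passing `is_forward=True`
+        flips the sequence to increasing order.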
+ """ + if t_min >= t_max: + raise ValueError("t_min must be less than t_max") + + if not isinstance(ts_order, (int, float)): + raise TypeError("ts_order must be an integer or float") + + step_indices = torch.arange(num_steps + 1, dtype=torch.float64) + time_steps = ( + t_max ** (1 / ts_order) + step_indices / num_steps * (t_min ** (1 / ts_order) - t_max ** (1 / ts_order)) + ) ** ts_order + + if is_forward: + return time_steps.flip(dims=(0,)) + + return time_steps + + +class Sampler(torch.nn.Module): + def __init__(self, cfg: Optional[SamplerConfig] = None): + super().__init__() + if cfg is None: + cfg = SamplerConfig() + self.cfg = cfg + + @torch.no_grad() + def forward( + self, + x0_fn: Callable, + x_sigma_max: torch.Tensor, + num_steps: int = 35, + sigma_min: float = 0.002, + sigma_max: float = 80, + rho: float = 7, + S_churn: float = 0, + S_min: float = 0, + S_max: float = float("inf"), + S_noise: float = 1, + solver_option: str = "2ab", + ) -> torch.Tensor: + in_dtype = x_sigma_max.dtype + + def float64_x0_fn(x_B_StateShape: torch.Tensor, t_B: torch.Tensor) -> torch.Tensor: + return x0_fn(x_B_StateShape.to(in_dtype), t_B.to(in_dtype)).to(torch.float64) + + is_multistep = is_multi_step_fn_supported(solver_option) + is_rk = is_runge_kutta_fn_supported(solver_option) + assert is_multistep or is_rk, f"Only support multistep or Runge-Kutta method, got {solver_option}" + + solver_cfg = SolverConfig( + s_churn=S_churn, + s_t_max=S_max, + s_t_min=S_min, + s_noise=S_noise, + is_multi=is_multistep, + rk=solver_option, + multistep=solver_option, + ) + timestamps_cfg = SolverTimestampConfig(nfe=num_steps, t_min=sigma_min, t_max=sigma_max, order=rho) + sampler_cfg = SamplerConfig(solver=solver_cfg, timestamps=timestamps_cfg, sample_clean=True) + + return self._forward_impl(float64_x0_fn, x_sigma_max, sampler_cfg).to(in_dtype) + + @torch.no_grad() + def _forward_impl( + self, + denoiser_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + noisy_input_B_StateShape: torch.Tensor, + sampler_cfg: Optional[SamplerConfig] = None, + callback_fns: Optional[List[Callable]] = None, + ) -> torch.Tensor: + """ + Internal implementation of the forward pass. + + Args: + denoiser_fn: Function to denoise the input. + noisy_input_B_StateShape: Input tensor with noise. + sampler_cfg: Configuration for the sampler. + callback_fns: List of callback functions to be called during sampling. + + Returns: + torch.Tensor: Denoised output tensor. + """ + sampler_cfg = self.cfg if sampler_cfg is None else sampler_cfg + solver_order = 1 if sampler_cfg.solver.is_multi else int(sampler_cfg.solver.rk[0]) + num_timestamps = sampler_cfg.timestamps.nfe // solver_order + + sigmas_L = get_rev_ts( + sampler_cfg.timestamps.t_min, sampler_cfg.timestamps.t_max, num_timestamps, sampler_cfg.timestamps.order + ).to(noisy_input_B_StateShape.device) + + denoised_output = differential_equation_solver( + denoiser_fn, sigmas_L, sampler_cfg.solver, callback_fns=callback_fns + )(noisy_input_B_StateShape) + + if sampler_cfg.sample_clean: + # Override denoised_output with fully denoised version + ones = torch.ones(denoised_output.size(0), device=denoised_output.device, dtype=denoised_output.dtype) + denoised_output = denoiser_fn(denoised_output, sigmas_L[-1] * ones) + + return denoised_output + + +def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_val: Any) -> Any: + """ + Implements a for loop with a function. + + Args: + lower: Lower bound of the loop (inclusive). + upper: Upper bound of the loop (exclusive). 
+ body_fun: Function to be applied in each iteration. + init_val: Initial value for the loop. + + Returns: + The final result after all iterations. + """ + val = init_val + for i in range(lower, upper): + val = body_fun(i, val) + return val + + +def differential_equation_solver( + x0_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + sigmas_L: torch.Tensor, + solver_cfg: SolverConfig, + callback_fns: Optional[List[Callable]] = None, +) -> Callable[[torch.Tensor], torch.Tensor]: + """ + Creates a differential equation solver function. + + Args: + x0_fn: Function to compute x0 prediction. + sigmas_L: Tensor of sigma values with shape [L,]. + solver_cfg: Configuration for the solver. + callback_fns: Optional list of callback functions. + + Returns: + A function that solves the differential equation. + """ + num_step = len(sigmas_L) - 1 + + if solver_cfg.is_multi: + update_step_fn = get_multi_step_fn(solver_cfg.multistep) + else: + update_step_fn = get_runge_kutta_fn(solver_cfg.rk) + + eta = min(solver_cfg.s_churn / (num_step + 1), math.sqrt(1.2) - 1) + + def sample_fn(input_xT_B_StateShape: torch.Tensor) -> torch.Tensor: + """ + Samples from the differential equation. + + Args: + input_xT_B_StateShape: Input tensor with shape [B, StateShape]. + + Returns: + Output tensor with shape [B, StateShape]. + """ + ones_B = torch.ones(input_xT_B_StateShape.size(0), device=input_xT_B_StateShape.device, dtype=torch.float64) + + def step_fn( + i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]] + ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]: + input_x_B_StateShape, x0_preds = state + sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1] + + # algorithm 2: line 4-6 + if solver_cfg.s_t_min < sigma_cur_0 < solver_cfg.s_t_max: + hat_sigma_cur_0 = sigma_cur_0 + eta * sigma_cur_0 + input_x_B_StateShape = input_x_B_StateShape + ( + hat_sigma_cur_0**2 - sigma_cur_0**2 + ).sqrt() * solver_cfg.s_noise * torch.randn_like(input_x_B_StateShape) + sigma_cur_0 = hat_sigma_cur_0 + + if solver_cfg.is_multi: + x0_pred_B_StateShape = x0_fn(input_x_B_StateShape, sigma_cur_0 * ones_B) + output_x_B_StateShape, x0_preds = update_step_fn( + input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_pred_B_StateShape, x0_preds + ) + else: + output_x_B_StateShape, x0_preds = update_step_fn( + input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_fn + ) + + if callback_fns: + for callback_fn in callback_fns: + callback_fn(**locals()) + + return output_x_B_StateShape, x0_preds + + x_at_eps, _ = fori_loop(0, num_step, step_fn, [input_xT_B_StateShape, None]) + return x_at_eps + + return sample_fn diff --git a/retinaface_utils.py b/retinaface_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5d1bc4c8a22a942736ae6b73a4ebb21da4980adc --- /dev/null +++ b/retinaface_utils.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +from pytorch_retinaface.utils.nms.py_cpu_nms import py_cpu_nms + +from .log import log + + +# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py +def filter_detected_boxes(boxes, scores, confidence_threshold, nms_threshold, top_k, keep_top_k): + """Filter boxes based on confidence score and remove overlapping boxes using NMS.""" + # Keep detections with confidence above threshold + inds = np.where(scores > confidence_threshold)[0] + boxes = boxes[inds] + scores = scores[inds] + + # Sort by confidence and keep top K detections + order = scores.argsort()[::-1][:top_k] + boxes = boxes[order] + scores = scores[order] + + # Run non-maximum-suppression (NMS) to remove overlapping boxes + dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) + keep = py_cpu_nms(dets, nms_threshold) + dets = dets[keep, :] + dets = dets[:keep_top_k, :] + boxes = dets[:, :-1] + return boxes + + +# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/utils/box_utils.py to handle batched inputs +def decode_batch(loc, priors, variances): + """Decode batched locations from predictions using priors and variances. + + Args: + loc (tensor): Batched location predictions for loc layers. + Shape: [batch_size, num_priors, 4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors, 4] + variances: (list[float]): Variances of prior boxes. + + Return: + Decoded batched bounding box predictions + Shape: [batch_size, num_priors, 4] + """ + batch_size = loc.size(0) + priors = priors.unsqueeze(0).expand(batch_size, -1, -1) + + boxes = torch.cat( + ( + priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1]), + ), + dim=2, + ) + + boxes[:, :, :2] -= boxes[:, :, 2:] / 2 + boxes[:, :, 2:] += boxes[:, :, :2] + return boxes + + +# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py +def _check_keys(model, pretrained_state_dict): + ckpt_keys = set(pretrained_state_dict.keys()) + model_keys = set(model.state_dict().keys()) + used_pretrained_keys = model_keys & ckpt_keys + unused_pretrained_keys = ckpt_keys - model_keys + missing_keys = model_keys - ckpt_keys + log.debug("Missing keys:{}".format(len(missing_keys))) + log.debug("Unused checkpoint keys:{}".format(len(unused_pretrained_keys))) + log.debug("Used keys:{}".format(len(used_pretrained_keys))) + assert len(used_pretrained_keys) > 0, "load NONE from pretrained checkpoint" + return True + + +# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py +def _remove_prefix(state_dict, prefix): + """Old version of the model is stored with all names of parameters sharing common prefix 'module.'""" + log.debug("Removing prefix '{}'".format(prefix)) + + def f(x): + return x.split(prefix, 1)[-1] if x.startswith(prefix) else x + + return {f(key): value for key, value in state_dict.items()} + + +# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py +def load_model(model, pretrained_path, load_to_cpu): + log.debug("Loading pretrained model from {}".format(pretrained_path)) + if load_to_cpu: + pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage, weights_only=True) + else: + device = torch.cuda.current_device() + pretrained_dict = torch.load( + pretrained_path, map_location=lambda 
storage, loc: storage.cuda(device), weights_only=True + ) + if "state_dict" in pretrained_dict.keys(): + pretrained_dict = _remove_prefix(pretrained_dict["state_dict"], "module.") + else: + pretrained_dict = _remove_prefix(pretrained_dict, "module.") + _check_keys(model, pretrained_dict) + model.load_state_dict(pretrained_dict, strict=False) + return model diff --git a/runge_kutta.py b/runge_kutta.py new file mode 100644 index 0000000000000000000000000000000000000000..ecffde890072dccffd2ffe67d1534044414266df --- /dev/null +++ b/runge_kutta.py @@ -0,0 +1,333 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Tuple + +import torch + +from .batch_ops import batch_mul + + +def phi1(t: torch.Tensor) -> torch.Tensor: + """ + Compute the first order phi function: (exp(t) - 1) / t. + + Args: + t: Input tensor. + + Returns: + Tensor: Result of phi1 function. + """ + input_dtype = t.dtype + t = t.to(dtype=torch.float64) + return (torch.expm1(t) / t).to(dtype=input_dtype) + + +def phi2(t: torch.Tensor) -> torch.Tensor: + """ + Compute the second order phi function: (phi1(t) - 1) / t. + + Args: + t: Input tensor. + + Returns: + Tensor: Result of phi2 function. + """ + input_dtype = t.dtype + t = t.to(dtype=torch.float64) + return ((phi1(t) - 1.0) / t).to(dtype=input_dtype) + + +def res_x0_rk2_step( + x_s: torch.Tensor, + t: torch.Tensor, + s: torch.Tensor, + x0_s: torch.Tensor, + s1: torch.Tensor, + x0_s1: torch.Tensor, +) -> torch.Tensor: + """ + Perform a residual-based 2nd order Runge-Kutta step. + + Args: + x_s: Current state tensor. + t: Target time tensor. + s: Current time tensor. + x0_s: Prediction at current time. + s1: Intermediate time tensor. + x0_s1: Prediction at intermediate time. + + Returns: + Tensor: Updated state tensor. + + Raises: + AssertionError: If step size is too small. + """ + s = -torch.log(s) + t = -torch.log(t) + m = -torch.log(s1) + + dt = t - s + assert not torch.any(torch.isclose(dt, torch.zeros_like(dt), atol=1e-6)), "Step size is too small" + assert not torch.any(torch.isclose(m - s, torch.zeros_like(dt), atol=1e-6)), "Step size is too small" + + c2 = (m - s) / dt + phi1_val, phi2_val = phi1(-dt), phi2(-dt) + + # Handle edge case where t = s = m + b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0) + b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0) + + return batch_mul(torch.exp(-dt), x_s) + batch_mul(dt, batch_mul(b1, x0_s) + batch_mul(b2, x0_s1)) + + +def reg_x0_euler_step( + x_s: torch.Tensor, + s: torch.Tensor, + t: torch.Tensor, + x0_s: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a regularized Euler step based on x0 prediction. + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_s: Prediction at current time. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and current prediction. 
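+
+    Note:
+        The update below is `x_t = ((s - t) / s) * x0_s + (t / s) * x_s`, a blend of the
+        current state and its x0 prediction whose coefficients sum to one; algebraically
+        this is an Euler step of the ODE `dx/dt = (x - x0(x, t)) / t` rewritten in terms
+        of the x0 prediction.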
+ """ + coef_x0 = (s - t) / s + coef_xs = t / s + return batch_mul(coef_x0, x0_s) + batch_mul(coef_xs, x_s), x0_s + + +def reg_eps_euler_step( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, eps_s: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a regularized Euler step based on epsilon prediction. + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + eps_s: Epsilon prediction at current time. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and current x0 prediction. + """ + return x_s + batch_mul(eps_s, t - s), x_s + batch_mul(eps_s, 0 - s) + + +def rk1_euler( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a first-order Runge-Kutta (Euler) step. + + Recommended for diffusion models with guidance or model undertrained + Usually more stable at the cost of a bit slower convergence. + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_fn: Function to compute x0 prediction. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and x0 prediction. + """ + x0_s = x0_fn(x_s, s) + return reg_x0_euler_step(x_s, s, t, x0_s) + + +def rk2_mid_stable( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a stable second-order Runge-Kutta (midpoint) step. + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_fn: Function to compute x0 prediction. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and x0 prediction. + """ + s1 = torch.sqrt(s * t) + x_s1, _ = rk1_euler(x_s, s, s1, x0_fn) + + x0_s1 = x0_fn(x_s1, s1) + return reg_x0_euler_step(x_s, s, t, x0_s1) + + +def rk2_mid(x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a second-order Runge-Kutta (midpoint) step. + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_fn: Function to compute x0 prediction. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and x0 prediction. + """ + s1 = torch.sqrt(s * t) + x_s1, x0_s = rk1_euler(x_s, s, s1, x0_fn) + + x0_s1 = x0_fn(x_s1, s1) + + return res_x0_rk2_step(x_s, t, s, x0_s, s1, x0_s1), x0_s1 + + +def rk_2heun_naive( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a naive second-order Runge-Kutta (Heun's method) step. + Impl based on rho-rk-deis solvers, https://github.com/qsh-zh/deis + Recommended for diffusion models without guidance and relative large NFE + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_fn: Function to compute x0 prediction. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and current state. + """ + x_t, x0_s = rk1_euler(x_s, s, t, x0_fn) + eps_s = batch_mul(1.0 / s, x_t - x0_s) + x0_t = x0_fn(x_t, t) + eps_t = batch_mul(1.0 / t, x_t - x0_t) + + avg_eps = (eps_s + eps_t) / 2 + + return reg_eps_euler_step(x_s, s, t, avg_eps) + + +def rk_2heun_edm( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a naive second-order Runge-Kutta (Heun's method) step. + Impl based no EDM second order Heun method + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_fn: Function to compute x0 prediction. 
+ + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and current state. + """ + x_t, x0_s = rk1_euler(x_s, s, t, x0_fn) + x0_t = x0_fn(x_t, t) + + avg_x0 = (x0_s + x0_t) / 2 + + return reg_x0_euler_step(x_s, s, t, avg_x0) + + +def rk_3kutta_naive( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a naive third-order Runge-Kutta step. + Impl based on rho-rk-deis solvers, https://github.com/qsh-zh/deis + Recommended for diffusion models without guidance and relative large NFE + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_fn: Function to compute x0 prediction. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and current state. + """ + c2, c3 = 0.5, 1.0 + a31, a32 = -1.0, 2.0 + b1, b2, b3 = 1.0 / 6, 4.0 / 6, 1.0 / 6 + + delta = t - s + + s1 = c2 * delta + s + s2 = c3 * delta + s + x_s1, x0_s = rk1_euler(x_s, s, s1, x0_fn) + eps_s = batch_mul(1.0 / s, x_s - x0_s) + x0_s1 = x0_fn(x_s1, s1) + eps_s1 = batch_mul(1.0 / s1, x_s1 - x0_s1) + + _eps = a31 * eps_s + a32 * eps_s1 + x_s2, _ = reg_eps_euler_step(x_s, s, s2, _eps) + + x0_s2 = x0_fn(x_s2, s2) + eps_s2 = batch_mul(1.0 / s2, x_s2 - x0_s2) + + avg_eps = b1 * eps_s + b2 * eps_s1 + b3 * eps_s2 + return reg_eps_euler_step(x_s, s, t, avg_eps) + + +# key : order + name +RK_FNs = { + "1euler": rk1_euler, + "2mid": rk2_mid, + "2mid_stable": rk2_mid_stable, + "2heun_edm": rk_2heun_edm, + "2heun_naive": rk_2heun_naive, + "3kutta_naive": rk_3kutta_naive, +} + + +def get_runge_kutta_fn(name: str) -> Callable: + """ + Get the specified Runge-Kutta function. + + Args: + name: Name of the Runge-Kutta method. + + Returns: + Callable: The specified Runge-Kutta function. + + Raises: + RuntimeError: If the specified method is not supported. + """ + if name in RK_FNs: + return RK_FNs[name] + methods = "\n\t".join(RK_FNs.keys()) + raise RuntimeError(f"Only support the following Runge-Kutta methods:\n\t{methods}") + + +def is_runge_kutta_fn_supported(name: str) -> bool: + """ + Check if the specified Runge-Kutta function is supported. + + Args: + name: Name of the Runge-Kutta method. + + Returns: + bool: True if the method is supported, False otherwise. + """ + return name in RK_FNs diff --git a/sampling.py b/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..b9719442e0fa94cea283348a7732e0413c5a7234 --- /dev/null +++ b/sampling.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch + +from .ar_transformer import Transformer + + +def sample_top_p(logits, temperature, top_p, return_probs: bool = False): + """ + Perform top-p (nucleus) sampling on a probability distribution. + + Args: + logits (torch.Tensor): Logits of the probability distribution. + temperature (float): Temperature for sampling. 
+ top_p (float): Probability threshold for top-p sampling. + + Returns: + torch.Tensor: Sampled token indices. + + Note: + Top-p sampling selects the smallest set of tokens whose cumulative probability mass + exceeds the threshold p. The distribution is renormalized based on the selected tokens. + """ + probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1) + # Sort the probabilities in descending order and get their indices. + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + # Compute the cumulative sum of the sorted probabilities. + probs_sum = torch.cumsum(probs_sort, dim=-1) + # Create a mask where the cumulative probability exceeds the threshold p. + mask = probs_sum - probs_sort > top_p + # Set the probabilities that exceed the threshold to 0. + probs_sort[mask] = 0.0 + # Renormalize the remaining probabilities so they sum to 1. + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + # Sample from the renormalized probability distribution. + # next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = multinomial_sample_one_no_sync(probs_sort, dtype=torch.int64) + # Gather the indices of the sampled tokens. + next_token = torch.gather(probs_idx, -1, next_token) + if return_probs: + # Initialize a tensor for unsorted probabilities + probs_unsorted = torch.zeros_like(probs_sort) + # Scatter the sorted probabilities back to their original order + probs_unsorted.scatter_(-1, probs_idx, probs_sort) + else: + probs_unsorted = None + return next_token, probs_unsorted + + +def multinomial_sample_one_no_sync(probs_sort, dtype=torch.int): + """ + Multinomial sampling without a cuda synchronization. + Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py + """ + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=dtype) + + +def logits_to_probs( + logits, + temperature: float = 1.0, + top_k: Optional[int] = None, +): + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def sample_top_k(logits, temperature: float = 1.0, top_k: Optional[int] = None): + """ + Sample from the logits using top-k sampling. 
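+    When `temperature` is 0.0 this falls back to greedy (argmax) decoding; otherwise the
+    logits are temperature-scaled, optionally truncated to the top-k entries, and sampled.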
+ Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py + """ + # logits: [batch_size, seq_len, vocab_size] + if temperature == 0.0: + idx_next = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True) + probs = None + else: + probs = logits_to_probs(logits[:, -1, :], temperature, top_k) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +def prefill( + model: Transformer, + input_pos: torch.Tensor, + tokens: torch.Tensor = None, + token_embeddings: torch.Tensor = None, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + **kwargs, +) -> torch.Tensor: + logits = model(tokens=tokens, token_embeddings=token_embeddings, input_pos=input_pos, **kwargs) + # Only top-p or top-k can be provided + assert ( + top_p is None or top_k is None + ), "Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}" + if top_p is not None: + return sample_top_p(logits, temperature=temperature, top_p=top_p)[0] + else: + return sample_top_k(logits, temperature=temperature, top_k=top_k)[0] + + +def decode_one_token( + model: Transformer, + tokens: torch.Tensor, + input_pos: torch.Tensor, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + **kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Decode a single token from the autoregressive model. + """ + logits = model(tokens=tokens, input_pos=input_pos, **kwargs) + if top_p is not None: + return sample_top_p(logits, temperature=temperature, top_p=top_p) + else: + return sample_top_k(logits, temperature=temperature, top_k=top_k) + + +def decode_n_tokens( + model: Transformer, + cur_token: torch.Tensor, + input_pos: torch.Tensor, + num_new_tokens: int, + stop_tokens: torch.Tensor = None, + temperature: float = 1.0, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + return_probs: bool = False, + decode_one_token_function=decode_one_token, + **kwargs, +): + """ + Decode n tokens from the autoregressive model. 
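+    Tokens are generated one at a time with `decode_one_token_function`; when `stop_tokens`
+    is provided, generation stops early once every sample in the batch has emitted a stop token.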
+ Adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py + """ + new_tokens, new_probs = [], [] + batch_size = cur_token.shape[0] + assert ( + top_p is None or top_k is None + ), "Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}" + if stop_tokens is not None: + # Indicator for whether the EOS token (stop token) has been reached for each sample in the batch + eos_reached = torch.tensor([False] * batch_size, device="cuda") + for t in range(num_new_tokens): + with torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=False, enable_math=True + ): # Actually better for Inductor to codegen attention here + next_token, next_prob = decode_one_token_function( + model, + tokens=cur_token, + input_pos=input_pos, + temperature=temperature, + top_k=top_k, + top_p=top_p, + **kwargs, + ) + input_pos += 1 + if stop_tokens is not None and len(stop_tokens) > 0: + eos_reached = eos_reached | (torch.isin(next_token, stop_tokens)) + if eos_reached.all(): + break + new_tokens.append(next_token.clone()) + if return_probs: + new_probs.append(next_prob.clone()) + cur_token = next_token.clone() + + if return_probs: + return new_tokens, new_probs + else: + return new_tokens diff --git a/t5_text_encoder.py b/t5_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7bebf08cef2869c85553980bf81851635dd74f7e --- /dev/null +++ b/t5_text_encoder.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Tuple, Union + +import torch +import transformers +from transformers import T5EncoderModel, T5TokenizerFast + +from .log import log + +transformers.logging.set_verbosity_error() + + +class CosmosT5TextEncoder(torch.nn.Module): + """Handles T5 text encoding operations.""" + + def __init__(self, model_name: str = "google-t5/t5-11b", device: str = "cuda", cache_dir: str = "~/.cache"): + """Initializes the T5 tokenizer and encoder. + + Args: + model_name: The name of the T5 model to use. + device: The device to use for computations. + """ + super().__init__() + try: + self.tokenizer = T5TokenizerFast.from_pretrained(model_name, cache_dir=cache_dir) + self.text_encoder = T5EncoderModel.from_pretrained(model_name, cache_dir=cache_dir).to(device) + except Exception as e: + log.warning(f"Failed to load T5 model using cache_dir '{cache_dir}', falling back to default location: {e}") + self.tokenizer = T5TokenizerFast.from_pretrained(model_name) + self.text_encoder = T5EncoderModel.from_pretrained(model_name).to(device) + self.text_encoder.eval() + self.device = device + + @torch.inference_mode() + def encode_prompts( + self, prompts: Union[str, List[str]], max_length: int = 512 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encodes text prompts into hidden state representations using a T5 encoder. 
+ + This function tokenizes the input prompts, processes them through a T5 text encoder, + and returns the last hidden states. The encoded outputs beyond the actual sequence + length are zero-padded. All prompts in a batch are padded to max_length. + + Args: + prompts: Input text to encode. Can be a single string or a list of strings. + max_length: Maximum sequence length for tokenization and padding. Longer + sequences will be truncated. Defaults to 512. + return_mask: If True, returns the attention mask along with encoded text. + Defaults to False. + + Returns: + If return_mask is False: + torch.Tensor: Encoded text embeddings of shape (batch_size, max_length, hidden_size). + If return_mask is True: + tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Encoded text embeddings of shape (batch_size, max_length, hidden_size) + - Attention mask of shape (batch_size, max_length) as boolean tensor + + Raises: + ValueError: If the input prompts list is empty. + + Example: + >>> encoder = CosmosT5TextEncoder() + >>> prompts = ["Hello world", "Another example"] + >>> embeddings = encoder.encode_prompts(prompts, max_length=128) + """ + if isinstance(prompts, str): + prompts = [prompts] + + if not prompts: + raise ValueError("The input prompt list is empty.") + + batch_encoding = self.tokenizer.batch_encode_plus( + prompts, + return_tensors="pt", + truncation=True, + padding="max_length", + max_length=max_length, + return_length=True, + return_offsets_mapping=False, + ) + + input_ids = batch_encoding.input_ids.to(self.device) + attn_mask = batch_encoding.attention_mask.to(self.device) + + outputs = self.text_encoder(input_ids=input_ids, attention_mask=attn_mask) + + encoded_text = outputs.last_hidden_state + lengths = attn_mask.sum(dim=1).cpu() + + for batch_id in range(encoded_text.shape[0]): + encoded_text[batch_id][lengths[batch_id] :] = 0 + + return encoded_text, attn_mask diff --git a/text2world.py b/text2world.py new file mode 100644 index 0000000000000000000000000000000000000000..b14d5a740e4ea4bbd3cdc5ac9151aed5bbb45bc0 --- /dev/null +++ b/text2world.py @@ -0,0 +1,161 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
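+
+"""
+Text to world generation demo script for Cosmos-1.0 Diffusion.
+
+Illustrative command only: the module path and the flags supplied by
+`add_common_arguments` (e.g. `--prompt`, `--checkpoint_dir`) are assumptions here,
+while the flags defined in `parse_arguments` below come from this file.
+
+    PYTHONPATH=$(pwd) python text2world.py \
+        --checkpoint_dir checkpoints \
+        --diffusion_transformer_dir Cosmos-1.0-Diffusion-7B-Text2World \
+        --prompt "A robot arm assembles a gearbox on a factory line."
+"""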
+ +import argparse +import os + +from .log import log +import torch + +from .inference_utils import add_common_arguments, validate_args +from .world_generation_pipeline import DiffusionText2WorldGenerationPipeline +from .misc import misc, Color, timer +from .utils_io import read_prompts_from_file, save_video + +torch.enable_grad(False) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Text to world generation demo script") + # Add common arguments + add_common_arguments(parser) + + # Add text2world specific arguments + parser.add_argument( + "--diffusion_transformer_dir", + type=str, + default="Cosmos-1.0-Diffusion-7B-Text2World", + help="DiT model weights directory name relative to checkpoint_dir", + choices=[ + "Cosmos-1.0-Diffusion-7B-Text2World", + "Cosmos-1.0-Diffusion-14B-Text2World", + ], + ) + parser.add_argument( + "--prompt_upsampler_dir", + type=str, + default="Cosmos-1.0-Prompt-Upsampler-12B-Text2World", + help="Prompt upsampler weights directory relative to checkpoint_dir", + ) + + parser.add_argument( + "--word_limit_to_skip_upsampler", + type=int, + default=250, + help="Skip prompt upsampler for better robustness if the number of words in the prompt is greater than this value", + ) + + return parser.parse_args() + + +def demo(cfg): + """Run text-to-world generation demo. + + This function handles the main text-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts from input + - Generating videos from text prompts + - Saving the generated videos and corresponding prompts to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (guidance, steps, dimensions) + - Input/output settings (prompts, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + - Text files containing the processed prompts + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. 
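+
+    Example:
+        Minimal programmatic sketch (assumes checkpoints already exist under
+        `checkpoint_dir` and that attribute names follow `add_common_arguments`):
+
+            args = parse_arguments()
+            args.prompt = "A robot arm assembles a gearbox."
+            demo(args)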
+ """ + misc.set_random_seed(cfg.seed) + inference_type = "text2world" + validate_args(cfg, inference_type) + + # Initialize text2world generation model pipeline + pipeline = DiffusionText2WorldGenerationPipeline( + inference_type=inference_type, + checkpoint_dir=cfg.checkpoint_dir, + checkpoint_name=cfg.diffusion_transformer_dir, + prompt_upsampler_dir=cfg.prompt_upsampler_dir, + enable_prompt_upsampler=not cfg.disable_prompt_upsampler, + offload_network=cfg.offload_diffusion_transformer, + offload_tokenizer=cfg.offload_tokenizer, + offload_text_encoder_model=cfg.offload_text_encoder_model, + offload_prompt_upsampler=cfg.offload_prompt_upsampler, + offload_guardrail_models=cfg.offload_guardrail_models, + guidance=cfg.guidance, + num_steps=cfg.num_steps, + height=cfg.height, + width=cfg.width, + fps=cfg.fps, + num_video_frames=cfg.num_video_frames, + seed=cfg.seed, + ) + + # Handle multiple prompts if prompt file is provided + if cfg.batch_input_path: + log.info(f"Reading batch inputs from path: {args.batch_input_path}") + prompts = read_prompts_from_file(cfg.batch_input_path) + else: + # Single prompt case + prompts = [{"prompt": cfg.prompt}] + + os.makedirs(cfg.video_save_folder, exist_ok=True) + for i, input_dict in enumerate(prompts): + current_prompt = input_dict.get("prompt", None) + if current_prompt is None: + log.critical("Prompt is missing, skipping world generation.") + continue + + # Generate video + generated_output = pipeline.generate(current_prompt, cfg.negative_prompt, cfg.word_limit_to_skip_upsampler) + if generated_output is None: + log.critical("Guardrail blocked text2world generation.") + continue + video, prompt = generated_output + + if cfg.batch_input_path: + video_save_path = os.path.join(cfg.video_save_folder, f"{i}.mp4") + prompt_save_path = os.path.join(cfg.video_save_folder, f"{i}.txt") + else: + video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4") + prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt") + + # Save video + save_video( + video=video, + fps=cfg.fps, + H=cfg.height, + W=cfg.width, + video_save_quality=5, + video_save_path=video_save_path, + ) + + # Save prompt to text file alongside video + with open(prompt_save_path, "wb") as f: + f.write(prompt.encode("utf-8")) + + log.info(f"Saved video to {video_save_path}") + log.info(f"Saved prompt to {prompt_save_path}") + + +if __name__ == "__main__": + args = parse_arguments() + demo(args) diff --git a/text2world_hf.py b/text2world_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..21a30f036ce3955456eca210c367237a1b09be95 --- /dev/null +++ b/text2world_hf.py @@ -0,0 +1,140 @@ +import os +import argparse +import torch +from transformers import PreTrainedModel, PretrainedConfig + +from .inference_utils import add_common_arguments, validate_args +from .world_generation_pipeline import DiffusionText2WorldGenerationPipeline +from .log import log +from .misc import misc, Color, timer +from .utils_io import read_prompts_from_file, save_video +from .df_config_config import attrs # this makes huggingface to download the file +from .download_diffusion import main as download_diffusion + + +# custom config class +class DiffusionText2WorldConfig(PretrainedConfig): + model_type = "DiffusionText2World" + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.diffusion_transformer_dir = kwargs.get("diffusion_transformer_dir", "Cosmos-1.0-Diffusion-7B-Text2World") + self.prompt_upsampler_dir = kwargs.get("prompt_upsampler_dir", 
"Cosmos-1.0-Prompt-Upsampler-12B-Text2World") + self.word_limit_to_skip_upsampler = kwargs.get("word_limit_to_skip_upsampler", 250) + self.checkpoint_dir = kwargs.get("checkpoint_dir", "checkpoints") + self.tokenizer_dir = kwargs.get("tokenizer_dir", "Cosmos-1.0-Tokenizer-CV8x8x8") + self.video_save_name = kwargs.get("video_save_name", "output") + self.video_save_folder = kwargs.get("video_save_folder", "outputs/") + self.prompt = kwargs.get("prompt", None) + self.batch_input_path = kwargs.get("batch_input_path", None) + self.negative_prompt = kwargs.get("negative_prompt", None) + self.num_steps = kwargs.get("num_steps", 35) + self.guidance = kwargs.get("guidance", 7) + self.num_video_frames = kwargs.get("num_video_frames", 121) + self.height = kwargs.get("height", 704) + self.width = kwargs.get("width", 1280) + self.fps = kwargs.get("fps", 24) + self.seed = kwargs.get("seed", 1) + self.disable_prompt_upsampler = kwargs.get("disable_prompt_upsampler", False) + self.offload_diffusion_transformer = kwargs.get("offload_diffusion_transformer", False) + self.offload_tokenizer = kwargs.get("offload_tokenizer", False) + self.offload_text_encoder_model = kwargs.get("offload_text_encoder_model", False) + self.offload_prompt_upsampler = kwargs.get("offload_prompt_upsampler", False) + self.offload_guardrail_models = kwargs.get("offload_guardrail_models", False) + + +# custom model calss +class DiffusionText2World(PreTrainedModel): + config_class = DiffusionText2WorldConfig + + def __init__(self, config=DiffusionText2WorldConfig()): + super().__init__(config) + torch.enable_grad(False) + self.config = config + inference_type = "text2world" + config.prompt = 1 # this is to hack args validation, maybe find a better way + validate_args(config, inference_type) + del config.prompt + self.pipeline = DiffusionText2WorldGenerationPipeline( + inference_type=inference_type, + checkpoint_dir=config.checkpoint_dir, + checkpoint_name=config.diffusion_transformer_dir, + prompt_upsampler_dir=config.prompt_upsampler_dir, + enable_prompt_upsampler=not config.disable_prompt_upsampler, + offload_network=config.offload_diffusion_transformer, + offload_tokenizer=config.offload_tokenizer, + offload_text_encoder_model=config.offload_text_encoder_model, + offload_prompt_upsampler=config.offload_prompt_upsampler, + offload_guardrail_models=config.offload_guardrail_models, + guidance=config.guidance, + num_steps=config.num_steps, + height=config.height, + width=config.width, + fps=config.fps, + num_video_frames=config.num_video_frames, + seed=config.seed, + ) + + # modifed from text2world.py demo function + def forward(self, prompt): + cfg = self.config + # Handle multiple prompts if prompt file is provided + if cfg.batch_input_path: + log.info(f"Reading batch inputs from path: {cfg.batch_input_path}") + prompts = read_prompts_from_file(cfg.batch_input_path) + else: + # Single prompt case + prompts = [{"prompt": prompt}] + + os.makedirs(cfg.video_save_folder, exist_ok=True) + for i, input_dict in enumerate(prompts): + current_prompt = input_dict.get("prompt", None) + if current_prompt is None: + log.critical("Prompt is missing, skipping world generation.") + continue + + # Generate video + generated_output = self.pipeline.generate(current_prompt, cfg.negative_prompt, cfg.word_limit_to_skip_upsampler) + if generated_output is None: + log.critical("Guardrail blocked text2world generation.") + continue + video, prompt = generated_output + + if cfg.batch_input_path: + video_save_path = os.path.join(cfg.video_save_folder, 
f"{i}.mp4") + prompt_save_path = os.path.join(cfg.video_save_folder, f"{i}.txt") + else: + video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4") + prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt") + + # Save video + save_video( + video=video, + fps=cfg.fps, + H=cfg.height, + W=cfg.width, + video_save_quality=5, + video_save_path=video_save_path, + ) + + # Save prompt to text file alongside video + with open(prompt_save_path, "wb") as f: + f.write(prompt.encode("utf-8")) + + log.info(f"Saved video to {video_save_path}") + log.info(f"Saved prompt to {prompt_save_path}") + + def save_pretrained(self, save_directory, **kwargs): + # We don't save anything, but need this function to override + pass + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + config = kwargs["config"] + other_args = kwargs.copy() + other_args.pop("config") + config.update(other_args) + model_sizes = ["7B",] if "7B" in config.diffusion_transformer_dir else ["14B",] + model_types = ["Text2World",] + download_diffusion(model_types, model_sizes, config.checkpoint_dir) + model = cls(config) + return model \ No newline at end of file diff --git a/text2world_prompt_upsampler_inference.py b/text2world_prompt_upsampler_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..9a2507b29fc56aff1c898fe6f2dbe652641dd32a --- /dev/null +++ b/text2world_prompt_upsampler_inference.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This demo script is used to run inference for Cosmos-1.0-Prompt-Upsampler-12B-Text2World. +Command: + PYTHONPATH=$(pwd) python cosmos1/models/diffusion/prompt_upsampler/text2world_prompt_upsampler_inference.py + +""" +import argparse +import os +import re + +from .model_config import create_text_model_config +from .ar_model import AutoRegressiveModel +from .inference import chat_completion +from .presets import presets as guardrail_presets +from .log import log + + +def create_prompt_upsampler(checkpoint_dir: str) -> AutoRegressiveModel: + model_config, tokenizer_config = create_text_model_config( + model_ckpt_path=os.path.join(checkpoint_dir, "model.pt"), + tokenizer_path=os.path.join(checkpoint_dir), + model_family="mistral", + model_size="12b", + is_instruct_model=True, + max_batch_size=1, + rope_dim="1D", + add_special_tokens=True, + max_seq_len=1024, + pytorch_rope_version="v1", + ) + log.debug(f"Text prompt upsampler model config: {model_config}") + + # Create and return a LLM instance + return AutoRegressiveModel.build( + model_config=model_config, + tokenizer_config=tokenizer_config, + ).to("cuda") + + +def run_chat_completion(model: AutoRegressiveModel, input: str, temperature: float = 0.01): + """ + text2world prompt upsampler model is finetuned for chat. 
+ During training, the context window for the initial prompt upsampler models is 512 tokens. For inference, we set max_seq_len to 1024 to accommodate longer inputs. + Setting `max_gen_len` is optional as the finetuned models can naturally determine when to stop generating. + """ + + dialogs = [[{"role": "user", "content": f"Upsample the short caption to a long caption: {str(input)}"}]] + + results = chat_completion( + model, + dialogs, + max_gen_len=512, + temperature=temperature, + top_p=None, + top_k=None, + logprobs=False, + ) + upsampled_prompt = str(clean_text(results[0]["generation"]["content"])) + return upsampled_prompt + + +def clean_text(text: str) -> str: + """Clean the text by removing prefixes, suffixes, formatting markers, and normalizing whitespace.""" + # Replace all variations of newlines with a space + text = text.replace("\n", " ").replace("\r", " ") + + # Use a regex to find sections of the form '- **...**' + pattern = r"(- \*\*)(.*?)(\*\*)" + + def replacement(match: re.Match[str]) -> str: + content = match.group(2) # The text inside - ** and ** + words = re.findall(r"\w+", content) + if len(words) < 10: + # If fewer than 10 words, remove the entire '- **...**' portion + return "" + else: + # If 10 or more words, keep the entire section as it is + return match.group(0) + + text = re.sub(pattern, replacement, text) + + # Remove common prefixes + prefixes = ["Caption:", "#####", "####", "- ", "* ", ","] + for prefix in prefixes: + # lstrip(prefix) won't strip entire strings, but character sets. + # For more reliable prefix removal, do: + if text.startswith(prefix): + text = text[len(prefix) :].lstrip() + + # Remove extra spaces + text = " ".join(text.split()) + + # Strip any remaining leading/trailing punctuation, whitespace, and quotes + text = text.strip(' -,*:"\'"“”') + + return text + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run prompt upsampler inference") + parser.add_argument("--input", type=str, default="A dog is playing with a ball.") + parser.add_argument("--temperature", type=float, default=0.01, help="Inference temperature") + parser.add_argument( + "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints" + ) + parser.add_argument( + "--prompt_upsampler_dir", + type=str, + default="Cosmos-1.0-Prompt-Upsampler-12B-Text2World", + help="Prompt upsampler weights directory relative to checkpoint_dir", + ) + parser.add_argument( + "--guardrail_dir", + type=str, + default="Cosmos-1.0-Guardrail", + help="Guardrail weights directory relative to checkpoint_dir", + ) + return parser.parse_args() + + +def main(args): + guardrail_runner = guardrail_presets.create_text_guardrail_runner( + os.path.join(args.checkpoint_dir, args.guardrail_dir) + ) + is_safe = guardrail_presets.run_text_guardrail(args.input, guardrail_runner) + if not is_safe: + log.critical("Input text prompt is not safe.") + return + + prompt_upsampler = create_prompt_upsampler(os.path.join(args.checkpoint_dir, args.prompt_upsampler_dir)) + upsampled_prompt = run_chat_completion(prompt_upsampler, args.input, temperature=args.temperature) + is_safe = guardrail_presets.run_text_guardrail(upsampled_prompt, guardrail_runner) + if not is_safe: + log.critical("Upsampled text prompt is not safe.") + return + + log.info(f"Upsampled prompt: {upsampled_prompt}") + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/timm.py b/timm.py new file mode 100644 index 
0000000000000000000000000000000000000000..ebe6dfe88a4afead7de85133fb74bbee181d7f49 --- /dev/null +++ b/timm.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings + +import torch + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/utils_io.py b/utils_io.py new file mode 100644 index 0000000000000000000000000000000000000000..c877aa41fd6b90638281f048bac23fc8214b84be --- /dev/null +++ b/utils_io.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from io import BytesIO
+from typing import Dict, List
+
+import imageio
+import numpy as np
+
+
+def read_prompts_from_file(prompt_file: str) -> List[Dict[str, str]]:
+    """Read prompts from a JSONL file where each line is a dict with 'prompt' key and optionally 'visual_input' key.
+
+    Args:
+        prompt_file (str): Path to JSONL file containing prompts
+
+    Returns:
+        List[Dict[str, str]]: List of prompt dictionaries
+    """
+    prompts = []
+    with open(prompt_file, "r") as f:
+        for line in f:
+            prompt_dict = json.loads(line.strip())
+            prompts.append(prompt_dict)
+    return prompts
+
+
+def save_video(video, fps, H, W, video_save_quality, video_save_path):
+    """Save video frames to file.
+
+    Args:
+        video (np.ndarray): Video frames array [T,H,W,C]
+        fps (int): Frames per second
+        H (int): Frame height
+        W (int): Frame width
+        video_save_quality (int): Video encoding quality (0-10)
+        video_save_path (str): Output video file path
+    """
+    kwargs = {
+        "fps": fps,
+        "quality": video_save_quality,
+        "macro_block_size": 1,
+        "ffmpeg_params": ["-s", f"{W}x{H}"],
+        "output_params": ["-f", "mp4"],
+    }
+    imageio.mimsave(video_save_path, video, "mp4", **kwargs)
+
+
+def load_from_fileobj(filepath: str, format: str = "mp4", mode: str = "rgb", **kwargs):
+    """
+    Load video from a file path using imageio with the specified format and color mode.
+
+    Parameters:
+        filepath (str): Path to the video file to load.
+        format (str): Format of the video file (default 'mp4').
+        mode (str): Color mode of the video, 'rgb' or 'gray' (default 'rgb').
+
+    Returns:
+        tuple: A tuple containing an array of video frames and metadata about the video.
+    """
+    with open(filepath, "rb") as f:
+        value = f.read()
+    with BytesIO(value) as f:
+        f.seek(0)
+        video_reader = imageio.get_reader(f, format, **kwargs)
+
+        video_frames = []
+        for frame in video_reader:
+            if mode == "gray":
+                import cv2  # Convert frame to grayscale if mode is gray
+
+                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
+                frame = np.expand_dims(frame, axis=2)  # Keep frame dimensions consistent
+            video_frames.append(frame)
+
+    return np.array(video_frames), video_reader.get_meta_data()
diff --git a/video2world.py b/video2world.py
new file mode 100644
index 0000000000000000000000000000000000000000..507c585d81832b51b053ba701b2c1b9f4298b50a
--- /dev/null
+++ b/video2world.py
@@ -0,0 +1,179 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+ +import argparse +import os + +from .log import log +import torch + +from .inference_utils import add_common_arguments, check_input_frames, validate_args +from .world_generation_pipeline import DiffusionVideo2WorldGenerationPipeline +from .misc import misc, Color, timer +from .utils_io import read_prompts_from_file, save_video + +torch.enable_grad(False) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Video to world generation demo script") + # Add common arguments + add_common_arguments(parser) + + # Add video2world specific arguments + parser.add_argument( + "--diffusion_transformer_dir", + type=str, + default="Cosmos-1.0-Diffusion-7B-Video2World", + help="DiT model weights directory name relative to checkpoint_dir", + choices=[ + "Cosmos-1.0-Diffusion-7B-Video2World", + "Cosmos-1.0-Diffusion-14B-Video2World", + ], + ) + parser.add_argument( + "--prompt_upsampler_dir", + type=str, + default="Pixtral-12B", + help="Prompt upsampler weights directory relative to checkpoint_dir", + ) + parser.add_argument( + "--input_image_or_video_path", + type=str, + help="Input video/image path for generating a single video", + ) + parser.add_argument( + "--num_input_frames", + type=int, + default=1, + help="Number of input frames for video2world prediction", + choices=[1, 9], + ) + + return parser.parse_args() + + +def demo(cfg): + """Run video-to-world generation demo. + + This function handles the main video-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts/images/videos from input + - Generating videos from prompts and images/videos + - Saving the generated videos and corresponding prompts to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (guidance, steps, dimensions) + - Input/output settings (prompts/images/videos, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + - Text files containing the processed prompts + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. 
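+
+    Example invocation (illustrative; the input path is a placeholder, and the common
+    flags such as --checkpoint_dir and --prompt are assumed to be registered by
+    add_common_arguments):
+
+        PYTHONPATH=$(pwd) python video2world.py \
+            --checkpoint_dir checkpoints \
+            --diffusion_transformer_dir Cosmos-1.0-Diffusion-7B-Video2World \
+            --input_image_or_video_path path/to/input.jpg \
+            --num_input_frames 1 \
+            --prompt "A robot arm assembles a small engine."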
+ """ + misc.set_random_seed(cfg.seed) + inference_type = "video2world" + validate_args(cfg, inference_type) + + # Initialize video2world generation model pipeline + pipeline = DiffusionVideo2WorldGenerationPipeline( + inference_type=inference_type, + checkpoint_dir=cfg.checkpoint_dir, + checkpoint_name=cfg.diffusion_transformer_dir, + prompt_upsampler_dir=cfg.prompt_upsampler_dir, + enable_prompt_upsampler=not cfg.disable_prompt_upsampler, + offload_network=cfg.offload_diffusion_transformer, + offload_tokenizer=cfg.offload_tokenizer, + offload_text_encoder_model=cfg.offload_text_encoder_model, + offload_prompt_upsampler=cfg.offload_prompt_upsampler, + offload_guardrail_models=cfg.offload_guardrail_models, + guidance=cfg.guidance, + num_steps=cfg.num_steps, + height=cfg.height, + width=cfg.width, + fps=cfg.fps, + num_video_frames=cfg.num_video_frames, + seed=cfg.seed, + num_input_frames=cfg.num_input_frames, + ) + + # Handle multiple prompts if prompt file is provided + if cfg.batch_input_path: + log.info(f"Reading batch inputs from path: {args.batch_input_path}") + prompts = read_prompts_from_file(cfg.batch_input_path) + else: + # Single prompt case + prompts = [{"prompt": cfg.prompt, "visual_input": cfg.input_image_or_video_path}] + + os.makedirs(cfg.video_save_folder, exist_ok=True) + for i, input_dict in enumerate(prompts): + current_prompt = input_dict.get("prompt", None) + if current_prompt is None and cfg.disable_prompt_upsampler: + log.critical("Prompt is missing, skipping world generation.") + continue + current_image_or_video_path = input_dict.get("visual_input", None) + if current_image_or_video_path is None: + log.critical("Visual input is missing, skipping world generation.") + continue + + # Check input frames + if not check_input_frames(current_image_or_video_path, cfg.num_input_frames): + continue + + # Generate video + generated_output = pipeline.generate( + prompt=current_prompt, + image_or_video_path=current_image_or_video_path, + negative_prompt=cfg.negative_prompt, + ) + if generated_output is None: + log.critical("Guardrail blocked video2world generation.") + continue + video, prompt = generated_output + + if cfg.batch_input_path: + video_save_path = os.path.join(cfg.video_save_folder, f"{i}.mp4") + prompt_save_path = os.path.join(cfg.video_save_folder, f"{i}.txt") + else: + video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4") + prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt") + + # Save video + save_video( + video=video, + fps=cfg.fps, + H=cfg.height, + W=cfg.width, + video_save_quality=5, + video_save_path=video_save_path, + ) + + # Save prompt to text file alongside video + with open(prompt_save_path, "wb") as f: + f.write(prompt.encode("utf-8")) + + log.info(f"Saved video to {video_save_path}") + log.info(f"Saved prompt to {prompt_save_path}") + + +if __name__ == "__main__": + args = parse_arguments() + demo(args) diff --git a/video2world_prompt_upsampler_inference.py b/video2world_prompt_upsampler_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..bf6c8b93d92c8fbd84a9e2047c15b2d29cc98cc0 --- /dev/null +++ b/video2world_prompt_upsampler_inference.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This demo script is used to run inference for Pixtral-12B. +Command: + PYTHONPATH=$(pwd) python cosmos1/models/diffusion/prompt_upsampler/video2world_prompt_upsampler_inference.py + +""" + +import argparse +import os +from math import ceil + +from PIL import Image + +from .model_config import create_vision_language_model_config +from .ar_model import AutoRegressiveModel +from .inference import chat_completion +from .presets import presets as guardrail_presets +from .log import log +from .utils_io import load_from_fileobj + + +def create_vlm_prompt_upsampler( + checkpoint_dir: str, tokenizer_ckpt_path: str = "mistral-community/pixtral-12b" +) -> AutoRegressiveModel: + """ + Load the fine-tuned pixtral model for SimReady. + If pixtral_ckpt is not provided, use the pretrained checkpoint. + """ + model_ckpt_path = os.path.join(checkpoint_dir, "model.pt") + model_config, tokenizer_config = create_vision_language_model_config( + model_ckpt_path=model_ckpt_path, + tokenizer_ckpt_path=tokenizer_ckpt_path, + model_family="pixtral", + model_size="12b", + is_instruct_model=True, + max_batch_size=1, + max_seq_len=4300, + pytorch_rope_version="v1", + ) + # during instantiate, the weights will be downloaded (if not already cached) and loaded + return AutoRegressiveModel.build( + model_config=model_config, + tokenizer_config=tokenizer_config, + ).to("cuda") + + +def resize_image(image: Image.Image, max_size: int = 1024) -> Image.Image: + """ + Ensure that the image is no larger than max_size in both dimensions. + """ + image_width, image_height = image.size + max_width, max_height = max_size, max_size + ratio = max(image_width / max_width, image_height / max_height) + if ratio > 1: + image = image.resize((ceil(image_width / ratio), ceil(image_height / ratio))) + return image + + +def prepare_dialog(image_or_video_path: str) -> list[dict]: + if image_or_video_path.endswith(".mp4"): + video_np, _ = load_from_fileobj(image_or_video_path, format="mp4") + image_frame = video_np[-1] + image = Image.fromarray(image_frame) + else: + image: Image.Image = Image.open(image_or_video_path) + + image = resize_image(image, max_size=1024) + prompt = """\ +Your task is to transform a given prompt into a refined and concise video description, no more than 150 words. +Focus only on the content, no filler words or descriptions on the style. Never mention things outside the video. 
+ """.strip() + + return [ + { + "role": "user", + "content": "[IMG]\n" + prompt, + "images": [image], + } + ] + + +def run_chat_completion(pixtral: AutoRegressiveModel, dialog: list[dict], **inference_args) -> str: + default_args = { + "max_gen_len": 400, + "temperature": 0, + "top_p": 0.9, + "logprobs": False, + "compile_sampling": False, + "compile_prefill": False, + } + default_args.update(inference_args) + results = chat_completion( + pixtral, + [dialog], + **default_args, + ) + assert len(results) == 1 + upsampled_prompt = str(results[0]["generation"]["content"]) + return upsampled_prompt + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run prompt upsampler inference") + parser.add_argument( + "--image_or_video_path", type=str, default="cosmos1/models/diffusion/assets/v1p0/video2world_input0.jpg" + ) + parser.add_argument("--temperature", type=float, default=0.01, help="Inference temperature") + parser.add_argument("--top_p", type=float, default=0.9, help="Top-p value for top-p sampling") + parser.add_argument( + "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints" + ) + parser.add_argument( + "--prompt_upsampler_dir", + type=str, + default="Pixtral-12B", + help="Prompt upsampler weights directory relative to checkpoint_dir", + ) + parser.add_argument( + "--guardrail_dir", + type=str, + default="Cosmos-1.0-Guardrail", + help="Guardrail weights directory relative to checkpoint_dir", + ) + return parser.parse_args() + + +def main(args): + guardrail_runner = guardrail_presets.create_text_guardrail_runner( + os.path.join(args.checkpoint_dir, args.guardrail_dir) + ) + + pixtral = create_vlm_prompt_upsampler(os.path.join(args.checkpoint_dir, args.prompt_upsampler_dir)) + dialog = prepare_dialog(args.image_or_video_path) + upsampled_prompt = run_chat_completion( + pixtral, + dialog, + max_gen_len=400, + temperature=args.temperature, + top_p=args.top_p, + logprobs=False, + ) + is_safe = guardrail_presets.run_text_guardrail(upsampled_prompt, guardrail_runner) + if not is_safe: + log.critical("Upsampled text prompt is not safe.") + return + + log.info(f"Upsampled prompt: {upsampled_prompt}") + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/video_content_safety_filter.py b/video_content_safety_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..e2d111f81d1abdd9629a8c87a710e078e906da65 --- /dev/null +++ b/video_content_safety_filter.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import json +import os +from typing import Iterable, Tuple, Union + +from .log import log +import torch +from PIL import Image + +from .guardrail_core import ContentSafetyGuardrail, GuardrailRunner +from .guardrail_io_utils import get_video_filepaths, read_video +from .video_content_safety_filter_model import ModelConfig, VideoSafetyModel +from .video_content_safety_filter_vision_encoder import SigLIPEncoder +from .misc import misc, Color, timer + +DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/video_content_safety_filter" + + +# Define the class index to class name mapping for multi-class classification +CLASS_IDX_TO_NAME = { + 0: "Safe", + 1: "Sexual_Content", + 2: "Violence", + 3: "Drugs", + 4: "Child_Abuse", + 5: "Hate_and_Harassment", + 6: "Self-Harm", +} + + +class VideoContentSafetyFilter(ContentSafetyGuardrail): + def __init__( + self, + checkpoint_dir: str = DEFAULT_CHECKPOINT_DIR, + device="cuda" if torch.cuda.is_available() else "cpu", + ) -> None: + self.device = device + self.dtype = torch.float32 + + # Initialize the SigLIP encoder + self.encoder = SigLIPEncoder(checkpoint_dir=checkpoint_dir, device=device, dtype=self.dtype) + + # Use ModelConfig directly for inference configuration + model_config = ModelConfig(input_size=1152, num_classes=7) + + # Load the multi-class classifier + self.model = VideoSafetyModel(model_config) + safety_filter_local_path = os.path.join(checkpoint_dir, "safety_filter.pt") + checkpoint = torch.load(safety_filter_local_path, map_location=torch.device("cpu"), weights_only=True) + self.model.load_state_dict(checkpoint["model"]) + self.model.to(self.device, dtype=self.dtype).eval() + + @torch.inference_mode() + def __infer(self, pil_image: Image.Image) -> int: + """Infer the class of the image.""" + image_embs = self.encoder.encode_image(pil_image) + logits = self.model.network(image_embs) + probabilities = torch.nn.functional.softmax(logits, dim=-1) + predicted_class = torch.argmax(probabilities, dim=-1).item() + return predicted_class + + def is_safe_file(self, filepath: str) -> bool: + """Check if the video file is safe.""" + video_data = read_video(filepath) + + # Sample frames at 2 FPS + sample_rate = 2 # frames per second + frame_interval = int(video_data.fps / sample_rate) + frame_numbers = list(range(0, int(video_data.fps * video_data.duration), frame_interval)) + + is_safe = True + frame_scores = [] + + for frame_number in frame_numbers: + try: + frame = video_data.frames[frame_number] + pil_image = Image.fromarray(frame) + predicted_class = self.__infer(pil_image) + class_name = CLASS_IDX_TO_NAME.get(predicted_class, "Unknown") + frame_scores.append({"frame_number": frame_number, "class": class_name}) + + # If any frame is not "Safe", mark the video as unsafe + if predicted_class != 0: + is_safe = False + break + + except Exception as e: + log.warning(f"Warning: Failed to run safety classifier on frame_number {frame_number}. 
Exception: {e}") + continue + + # Prepare data for JSON + video_data = { + "filepath": filepath, + "is_safe": is_safe, + "video_length": video_data.duration, + "fps": video_data.fps, + "frame_scores": frame_scores, + } + + log.info(f"Video {filepath} is {'SAFE' if is_safe else 'UNSAFE'}.") + log.debug(f"Video data: {json.dumps(video_data, indent=4)}") + return is_safe + + def is_safe_frames(self, frames: Iterable) -> bool: + """Check if the video frames are safe.""" + is_safe = True + frame_scores = [] + + for frame_number, frame in enumerate(frames): + try: + pil_image = Image.fromarray(frame) + predicted_class = self.__infer(pil_image) + class_name = CLASS_IDX_TO_NAME.get(predicted_class, "Unknown") + frame_scores.append({"frame_number": frame_number, "class": class_name}) + + # If any frame is not "Safe", mark as not safe + if predicted_class != 0: + is_safe = False + break + + except Exception as e: + log.warning(f"Warning: Failed to run safety classifier on frame_number {frame_number}. Exception: {e}") + continue + + video_data = { + "is_safe": is_safe, + "frame_scores": frame_scores, + } + + log.debug(f"Frames data: {json.dumps(video_data, indent=4)}") + return is_safe + + def is_safe(self, input: Union[str, Iterable]) -> Tuple[bool, str]: + if isinstance(input, str): + is_safe = self.is_safe_file(input) + return is_safe, "safe video detected" if is_safe else "unsafe video detected" + elif isinstance(input, Iterable): + is_safe = self.is_safe_frames(input) + return is_safe, "safe frames detected" if is_safe else "unsafe frames detected" + else: + raise ValueError(f"Input type {type(input)} not supported.") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type=str, required=True, help="Path containing input videos") + parser.add_argument( + "--checkpoint_dir", + type=str, + help="Path to the Video Content Safety Filter checkpoint folder", + default=DEFAULT_CHECKPOINT_DIR, + ) + return parser.parse_args() + + +def main(args): + filepaths = get_video_filepaths(args.input_dir) + if not filepaths: + log.error(f"No video files found in directory: {args.input_dir}") + return + + video_filter = VideoContentSafetyFilter(checkpoint_dir=args.checkpoint_dir) + runner = GuardrailRunner(safety_models=[video_filter], generic_safe_msg="Video is safe") + + for filepath in filepaths: + with timer("video content safety filter"): + _ = runner.run_safety_check(filepath) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/video_content_safety_filter_model.py b/video_content_safety_filter_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ccc005d90bad4d029bf9fdc9a66450fa9f7049 --- /dev/null +++ b/video_content_safety_filter_model.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import attrs +import torch +import torch.nn as nn + +from .config import make_freezable + + +@make_freezable +@attrs.define(slots=False) +class ModelConfig: + input_size: int = 1152 + num_classes: int = 7 + + +class SafetyClassifier(nn.Module): + def __init__(self, input_size: int = 1024, num_classes: int = 2): + super().__init__() + self.input_size = input_size + self.num_classes = num_classes + self.layers = nn.Sequential( + nn.Linear(self.input_size, 512), + nn.BatchNorm1d(512), + nn.ReLU(), + nn.Linear(512, 256), + nn.BatchNorm1d(256), + nn.ReLU(), + nn.Linear(256, self.num_classes), + # Note: No activation function here; CrossEntropyLoss expects raw logits + ) + + def forward(self, x): + return self.layers(x) + + +class VideoSafetyModel(nn.Module): + def __init__(self, config: ModelConfig) -> None: + super().__init__() + self.config = config + self.num_classes = config.num_classes + self.network = SafetyClassifier(input_size=config.input_size, num_classes=self.num_classes) + + @torch.inference_mode() + def forward(self, data_batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + logits = self.network(data_batch["data"].cuda()) + return {"logits": logits} diff --git a/video_content_safety_filter_vision_encoder.py b/video_content_safety_filter_vision_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d73893a8b18ea3ff106b47a75c9123dba31cd9ba --- /dev/null +++ b/video_content_safety_filter_vision_encoder.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from PIL import Image +from transformers import SiglipModel, SiglipProcessor + +DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/video_content_safety_filter" + + +class SigLIPEncoder(torch.nn.Module): + def __init__( + self, + model_name: str = "google/siglip-so400m-patch14-384", + checkpoint_dir: str = DEFAULT_CHECKPOINT_DIR, + device="cuda" if torch.cuda.is_available() else "cpu", + dtype=torch.float32, + ) -> None: + super().__init__() + self.checkpoint_dir = checkpoint_dir + self.device = device + self.dtype = dtype + self.model = SiglipModel.from_pretrained(model_name, cache_dir=self.checkpoint_dir) + self.processor = SiglipProcessor.from_pretrained(model_name, cache_dir=self.checkpoint_dir) + self.model.to(self.device, dtype=self.dtype).eval() + + @torch.inference_mode() + def encode_image(self, input_img: Image.Image) -> torch.Tensor: + """Encode an image into a feature vector.""" + with torch.no_grad(): + inputs = self.processor(images=input_img, return_tensors="pt").to(self.device, dtype=self.dtype) + image_features = self.model.get_image_features(**inputs) + image_features /= image_features.norm(dim=-1, keepdim=True) + return image_features diff --git a/vit.py b/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..4ff9bd018ae15cf88a341f465bb880c58cfd2f4f --- /dev/null +++ b/vit.py @@ -0,0 +1,410 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module implements a Vision Transformer (ViT) with 2D Rotary Position Embeddings, +designed for processing image inputs in vision-language models. + +This module follows Mistral's vision encoder implementation (for their Pistral-12B VLM): +https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py +""" +from functools import partial +from typing import Any, Callable, Mapping, Optional, Tuple + +import torch +import torch.nn as nn + +from .ar_modules_normalization import create_norm +from .ar_transformer import TransformerBlock +from .log import log + + +def get_vit_config(model_name: str) -> Mapping[str, Any]: + """ + Get the ViT configuration for a given model name. + """ + if model_name == "pixtral-12b-vit": + # The 400M ViT of Pixtral 12B VLM + return dict( + dim=1024, + num_channels=3, + image_size=1024, + patch_size=16, + rope_theta=10000, + ffn_hidden_size=4096, + n_layers=24, + n_heads=16, + n_kv_heads=16, + norm_type="rmsnorm", + norm_eps=1e-5, + image_token_id=10, + ) + else: + raise ValueError(f"Unknown model name: {model_name}") + + +def precompute_freqs_cis_2d( + dim: int, + height: int, + width: int, + theta: float, +) -> torch.Tensor: + """ + Precompute 2D complex tensor for rotary position embedding. + + This function generates a 2D complex tensor used for rotary position embeddings, + which helps the model understand spatial relationships in the input image. 
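+    For example, with the pixtral-12b-vit configuration above (dim=1024, n_heads=16,
+    image_size=1024, patch_size=16), the vision transformer calls this with
+    dim = 1024 // 16 = 64 and height = width = 1024 // 16 = 64, so the returned tensor
+    has shape (64, 64, 32), i.e. (height, width, dim // 2).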
+ + Args: + dim (int): Dimension of the model (typically the hidden size divided by number of heads). + height (int): Height of the image in patches. + width (int): Width of the image in patches. + theta (float): Base value for the angle calculation, controls the frequency range. + + Returns: + torch.Tensor: 2D complex tensor of shape (height, width, dim // 2). + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) + + h = torch.arange(height, device=freqs.device) + w = torch.arange(width, device=freqs.device) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + freqs_2d = torch.cat( + [ + freqs_h[:, None, :].repeat(1, width, 1), + freqs_w[None, :, :].repeat(height, 1, 1), + ], + dim=-1, + ) + return torch.polar(torch.ones_like(freqs_2d), freqs_2d) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + """ + Reshape frequency tensor for broadcasting with input tensor. + + This function ensures that the frequency tensor can be properly broadcast + with the input tensor during the rotary embedding process. + + Args: + freqs_cis (torch.Tensor): Frequency tensor from precompute_freqs_cis_2d. + x (torch.Tensor): Input tensor to be embedded. + + Returns: + torch.Tensor: Reshaped frequency tensor ready for broadcasting. + """ + ndim = x.ndim + assert 0 <= 1 < ndim, f"ndim is {ndim} but index is {1}" + assert freqs_cis.shape == ( + x.shape[1], + x.shape[-1], + ), f"freqs_cis shape is {freqs_cis.shape} but x shape is {x.shape}" + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + *args, + freqs_cis: torch.Tensor, + **kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary positional embeddings to input tensors. + + This function applies the rotary positional embeddings to the query and key tensors, + which helps the model understand spatial relationships in the input. + + Args: + xq (torch.Tensor): Query tensor. + xk (torch.Tensor): Key tensor. + freqs_cis (torch.Tensor): Precomputed frequencies from precompute_freqs_cis_2d. + *args: Variable length argument list (unused). + **kwargs: Arbitrary keyword arguments (unused). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors. + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class VisionTransformer(nn.Module): + """ + Vision Transformer model for image processing. + + This class implements a Vision Transformer that processes images using a patch-based approach + and applies transformer layers with rotary position embeddings. + + Args: + dim (int): Dimension of the model (hidden size). + num_channels (int): Number of input image channels (e.g., 3 for RGB). + patch_size (int): Size of each image patch (e.g., 16x16 pixels). + n_layers (int): Number of transformer layers. + n_heads (int): Number of attention heads. + ffn_hidden_size (int): Hidden size of the feed-forward network in transformer blocks. + norm_type (str): Type of normalization to use (e.g., "rmsnorm"). + norm_eps (float): Epsilon value for normalization layers. + image_size (int): Size of the input image (assumed square). 
+ rope_theta (float): Base value for rotary position embedding calculation. + attention_dropout (float): Dropout rate for attention layers. + hidden_dropout (float): Dropout rate for hidden layers. + image_token_id (int): Token ID for the image token (if present). + """ + + def __init__( + self, + dim: int = 1024, + num_channels: int = 3, + patch_size: int = 16, + n_layers: int = 24, + n_heads: int = 16, + n_kv_heads: int = None, + ffn_hidden_size: int = 4096, + norm_type: str = "rmsnorm", + norm_eps: float = 1e-5, + image_size: int = 1024, + rope_theta: float = 1000000.0, + image_token_id: int = None, + ): + super().__init__() + self.patch_conv = nn.Conv2d( + in_channels=num_channels, + out_channels=dim, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + self.ln_pre = create_norm(norm_type=norm_type, dim=dim, eps=norm_eps) + if n_kv_heads is None: + n_kv_heads = n_heads + layer_args = dict( + n_layers=n_layers, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + dim=dim, + use_qk_normalization=False, + max_seq_len=None, + max_batch_size=None, + ffn_hidden_size=ffn_hidden_size, + norm_type=norm_type, + norm_eps=norm_eps, + causal_mask=False, # Full attention in ViT + head_dim=None, + insert_cross_attn=False, + attn_type="full", + ) + + self.transformer = VisionTransformerBlocks(n_layers=n_layers, args=layer_args) + + head_dim = dim // n_heads + assert head_dim % 2 == 0, "ROPE requires even head_dim" + + self.dim = dim + self.n_heads = n_heads + self.max_patches_per_side = image_size // patch_size + self.image_size = image_size + self.patch_size = patch_size + self.rope_theta = rope_theta + self._freqs_cis: Optional[torch.Tensor] = None + self.image_token_id = image_token_id + + num_params = self.get_num_params() + log.debug(f"Number of model parameters: {round(num_params / 1e6, 3)}M") + + @classmethod + def build( + cls, + config: Mapping[str, Any], + ) -> "VisionTransformer": + """ + Create a Vision Transformer from a configuration dictionary. + + This class method creates a Vision Transformer from a configuration dictionary, + which is typically loaded from a JSON file or other configuration source. + + Args: + config (Mapping[str, Any]): Configuration dictionary for the Vision Transformer. + + Returns: + VisionTransformer: Vision Transformer model instance. + """ + necessary_keys = ["dim", "num_channels", "patch_size", "n_layers", "n_heads", "ffn_hidden_size", "rope_theta"] + missing_keys = [k for k in necessary_keys if k not in config] + assert len(missing_keys) == 0, f"Missing keys in config: {missing_keys}" + return cls( + **config, + ) + + def expand_in_channels(self, new_in_channels: int): + """ + Expand the input channels of the patch convolution layer. + This is useful when the input is non-standard, e.g. a 4-channel image with the last channel as the alpha channel. + Note that you should only call this method after the weight is loaded. + """ + assert ( + new_in_channels > self.patch_conv.in_channels + ), "Cannot expand the input channels of the patch convolution layer to be less than the original number of channels." + log.debug( + f"Vision encoder in_channels is {self.patch_conv.in_channels}. But you have specified to be {new_in_channels}. We will change it to {new_in_channels} channels with {new_in_channels - self.patch_conv.in_channels} channels of 0s." 
+ ) + new_conv = nn.Conv2d( + in_channels=new_in_channels, + out_channels=self.patch_conv.out_channels, + kernel_size=self.patch_conv.kernel_size, + stride=self.patch_conv.stride, + bias=False, + ) + new_conv.weight.data[:, : self.patch_conv.in_channels].copy_(self.patch_conv.weight.data) + new_conv.weight.data[ + :, self.patch_conv.in_channels : + ].zero_() # zeroize, such that initially it has no effect to output + self.patch_conv = new_conv + + @property + def device(self) -> torch.device: + """Get the device of the model.""" + return next(self.parameters()).device + + @property + def freqs_cis(self) -> torch.Tensor: + """ + Get or compute the frequency tensor for rotary position embedding. + + This property lazily initializes and caches the frequency tensor used for + rotary position embeddings, ensuring it's on the correct device. + + Returns: + torch.Tensor: The frequency tensor for rotary position embeddings. + """ + if self._freqs_cis is None: + self._freqs_cis = precompute_freqs_cis_2d( + dim=self.dim // self.n_heads, + height=self.max_patches_per_side, + width=self.max_patches_per_side, + theta=self.rope_theta, + ) + + if self._freqs_cis.device != self.device: + self._freqs_cis = self._freqs_cis.to(device=self.device) + + return self._freqs_cis + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass of the Vision Transformer. + + This method processes the input image through the Vision Transformer, + including patch embedding, position embedding, and transformer layers. + + Args: + x (torch.Tensor): Input tensor of shape (B, C, H, W), where B is batch size, + C is number of channels, and H, W are height and width. + + Returns: + torch.Tensor: Output features of shape (B, N, D), where N is the number of patches + and D is the embedding dimension. + """ + + patch_embeds = self.patch_conv(x) # (B, D, Hp, Wp) + _, _, Hp, Wp = patch_embeds.shape # Patch embeds dim + patch_embeds = patch_embeds.flatten(2) # (B, D, Hp*Wp) + patch_embeds = patch_embeds.transpose(1, 2) # (B, Hp*Wp, D) + patch_embeds = self.ln_pre(patch_embeds) # (B, Hp*Wp, D) + positions = torch.stack( + torch.meshgrid( + torch.arange(Hp), + torch.arange(Wp), + indexing="ij", + ), + dim=-1, + ).reshape(-1, 2) + + freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]] + rope = partial(apply_rotary_emb, freqs_cis=freqs_cis) + out = self.transformer(patch_embeds, rope=rope) + + return out + + def get_num_params( + self, + ) -> int: + """ + Return the number of parameters in the model. + """ + n_params = sum(p.numel() for p in self.parameters()) + return n_params + + +class VisionTransformerBlocks(nn.Module): + """ + Vision Transformer Blocks. + + This class implements a stack of Transformer blocks used in the Vision Transformer. + + Args: + n_layers (int): Number of transformer layers. + args (Mapping[str, Any]): Arguments for each transformer block, including dimensions, + """ + + def __init__( + self, + n_layers: int, + args: Mapping[str, Any], + ): + super().__init__() + self.layers = torch.nn.ModuleList() + + for layer_id in range(n_layers): + self.layers.append( + TransformerBlock( + layer_id=layer_id, + args=args, + ) + ) + + def forward( + self, + x: torch.Tensor, + rope: Callable, + ) -> torch.Tensor: + """ + Forward pass through the Vision Transformer Blocks. + + This method applies a series of Transformer blocks to the input tensor, + using the provided rotary position embedding function. 
+ + Args: + x (torch.Tensor): Input tensor of shape (B, N, D), where B is batch size, + N is the number of patches, and D is the embedding dimension. + rope (Callable): Rotary position embedding function to be applied in each layer. + + Returns: + torch.Tensor: Output tensor after passing through all transformer layers, + with the same shape as the input. + """ + for layer in self.layers: + x = layer(x, input_pos=None, mask=None, rope=rope) + return x diff --git a/world_generation_pipeline.py b/world_generation_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..7b49d543c293961c0db81717d3dee8a736c4f091 --- /dev/null +++ b/world_generation_pipeline.py @@ -0,0 +1,658 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os +from typing import Any, Optional + +import numpy as np +import torch + +from .base_world_generation_pipeline import BaseWorldGenerationPipeline +from .inference_utils import ( + generate_world_from_text, + generate_world_from_video, + get_condition_latent, + get_video_batch, + load_model_by_config, + load_network_model, + load_tokenizer_model, +) +from .model_t2w import DiffusionT2WModel +from .model_v2w import DiffusionV2WModel +from .text2world_prompt_upsampler_inference import ( + create_prompt_upsampler, + run_chat_completion, +) +from .video2world_prompt_upsampler_inference import ( + create_vlm_prompt_upsampler, + prepare_dialog, +) +from .video2world_prompt_upsampler_inference import ( + run_chat_completion as run_chat_completion_vlm, +) +from .log import log + +MODEL_NAME_DICT = { + "Cosmos-1.0-Diffusion-7B-Text2World": "Cosmos_1_0_Diffusion_Text2World_7B", + "Cosmos-1.0-Diffusion-14B-Text2World": "Cosmos_1_0_Diffusion_Text2World_14B", + "Cosmos-1.0-Diffusion-7B-Video2World": "Cosmos_1_0_Diffusion_Video2World_7B", + "Cosmos-1.0-Diffusion-14B-Video2World": "Cosmos_1_0_Diffusion_Video2World_14B", +} + + +class DiffusionText2WorldGenerationPipeline(BaseWorldGenerationPipeline): + def __init__( + self, + inference_type: str, + checkpoint_dir: str, + checkpoint_name: str, + prompt_upsampler_dir: Optional[str] = None, + enable_prompt_upsampler: bool = True, + enable_text_guardrail: bool = True, + enable_video_guardrail: bool = True, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + offload_prompt_upsampler: bool = False, + offload_guardrail_models: bool = False, + guidance: float = 7.0, + num_steps: int = 35, + height: int = 704, + width: int = 1280, + fps: int = 24, + num_video_frames: int = 121, + seed: int = 0, + ): + """Initialize the diffusion world generation pipeline. 
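+
+        A minimal construction sketch (the directory and checkpoint names below mirror
+        the defaults used elsewhere in this diff and are illustrative, not required):
+
+            pipeline = DiffusionText2WorldGenerationPipeline(
+                inference_type="text2world",
+                checkpoint_dir="checkpoints",
+                checkpoint_name="Cosmos-1.0-Diffusion-7B-Text2World",
+                prompt_upsampler_dir="Cosmos-1.0-Prompt-Upsampler-12B-Text2World",
+            )
+            # generate() returns (video, final_prompt), or None if a guardrail blocks the request
+            result = pipeline.generate("A robot arm assembles a small engine.")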
+ + Args: + inference_type: Type of world generation ('text2world' or 'video2world') + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the diffusion transformer checkpoint to use + prompt_upsampler_dir: Directory containing prompt upsampler model weights + enable_prompt_upsampler: Whether to use prompt upsampling + enable_text_guardrail: Whether to enable text guardrail + enable_video_guardrail: Whether to enable video guardrail + offload_network: Whether to offload diffusion transformer after inference + offload_tokenizer: Whether to offload tokenizer after inference + offload_text_encoder_model: Whether to offload T5 model after inference + offload_prompt_upsampler: Whether to offload prompt upsampler + offload_guardrail_models: Whether to offload guardrail models + guidance: Classifier-free guidance scale + num_steps: Number of diffusion sampling steps + height: Height of output video + width: Width of output video + fps: Frames per second of output video + num_video_frames: Number of frames to generate + seed: Random seed for sampling + """ + assert inference_type in [ + "text2world", + "video2world", + ], "Invalid inference_type, must be 'text2world' or 'video2world'" + + self.model_name = MODEL_NAME_DICT[checkpoint_name] + self.guidance = guidance + self.num_steps = num_steps + self.height = height + self.width = width + self.fps = fps + self.num_video_frames = num_video_frames + self.seed = seed + + super().__init__( + inference_type=inference_type, + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + enable_text_guardrail=enable_text_guardrail, + enable_video_guardrail=enable_video_guardrail, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + offload_text_encoder_model=offload_text_encoder_model, + offload_guardrail_models=offload_guardrail_models, + ) + self.prompt_upsampler_dir = prompt_upsampler_dir + self.enable_prompt_upsampler = enable_prompt_upsampler + self.offload_prompt_upsampler = offload_prompt_upsampler + + self.prompt_upsampler = None + if enable_prompt_upsampler and not offload_prompt_upsampler: + self._load_prompt_upsampler_model() + + def _load_prompt_upsampler_model(self): + self.prompt_upsampler = create_prompt_upsampler( + checkpoint_dir=os.path.join(self.checkpoint_dir, self.prompt_upsampler_dir), + ) + + def _load_model(self): + self.model = load_model_by_config( + config_job_name=self.model_name, + config_file="df_config_config.py", + model_class=DiffusionT2WModel, + ) + + def _load_network(self): + load_network_model(self.model, f"{self.checkpoint_dir}/{self.checkpoint_name}/model.pt") + + def _load_tokenizer(self): + load_tokenizer_model(self.model, f"{self.checkpoint_dir}/Cosmos-1.0-Tokenizer-CV8x8x8") + + def _offload_prompt_upsampler_model(self): + """Move prompt enhancement model to CPU/disk. + + Offloads prompt upsampling model after processing input + to reduce GPU memory usage. + """ + if self.prompt_upsampler: + del self.prompt_upsampler + self.prompt_upsampler = None + gc.collect() + torch.cuda.empty_cache() + + def _run_prompt_upsampler_on_prompt(self, prompt: str) -> str: + """Enhance the input prompt using the prompt upsampler model. 
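+
+        For example, a short caption such as "A dog is playing with a ball." (the
+        default input of text2world_prompt_upsampler_inference.py) is expanded into a
+        long, detailed scene description before the prompt is embedded by the text
+        encoder.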
+ + Args: + prompt: Raw text prompt to be enhanced + + Returns: + str: Enhanced version of the input prompt with more descriptive details + """ + upsampled_prompt = run_chat_completion(self.prompt_upsampler, prompt) + log.info(f"Upsampled prompt: {upsampled_prompt}") + return upsampled_prompt + + def _run_prompt_upsampler_on_prompt_with_offload(self, *args: Any, **kwargs: Any) -> str: + """Enhance prompt with prompt upsampler model. + + Args: + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + Enhanced prompt string + """ + if self.offload_prompt_upsampler: + self._load_prompt_upsampler_model() + + enhanced_prompt = self._run_prompt_upsampler_on_prompt(*args, **kwargs) + + if self.offload_prompt_upsampler: + self._offload_prompt_upsampler_model() + + return enhanced_prompt + + def _run_tokenizer_decoding(self, sample: torch.Tensor) -> np.ndarray: + """Decode latent samples to video frames using the tokenizer decoder. + + Args: + sample: Latent tensor from diffusion model [B, C, T, H, W] + + Returns: + np.ndarray: Decoded video frames as uint8 numpy array [T, H, W, C] + with values in range [0, 255] + """ + # Decode video + video = (1.0 + self.model.decode(sample)).clamp(0, 2) / 2 # [B, 3, T, H, W] + video = (video[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy() + + return video + + def _run_model( + self, + embedding: torch.Tensor, + negative_prompt_embedding: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Generate video latents using the diffusion model. + + Args: + embedding: Text embedding tensor from text encoder + negative_prompt_embedding: Optional embedding for negative prompt guidance + + Returns: + torch.Tensor: Generated video latents before tokenizer decoding + + Note: + The model and tokenizer are automatically offloaded after inference + if offloading is enabled in the config. + """ + # Get video batch and state shape + data_batch, state_shape = get_video_batch( + model=self.model, + prompt_embedding=embedding, + negative_prompt_embedding=negative_prompt_embedding, + height=self.height, + width=self.width, + fps=self.fps, + num_video_frames=self.num_video_frames, + ) + + # Generate video frames + sample = generate_world_from_text( + model=self.model, + state_shape=state_shape, + is_negative_prompt=True if negative_prompt_embedding is not None else False, + data_batch=data_batch, + guidance=self.guidance, + num_steps=self.num_steps, + seed=self.seed, + ) + + return sample + + def _run_model_with_offload( + self, prompt_embedding: torch.Tensor, negative_prompt_embedding: Optional[torch.Tensor] = None + ) -> np.ndarray: + """Generate world representation with automatic model offloading. + + Wraps the core generation process with model loading/offloading logic + to minimize GPU memory usage during inference. 
+
+ Args:
+ prompt_embedding: Text embedding tensor from the text encoder
+ negative_prompt_embedding: Optional embedding for negative prompt guidance
+
+ Returns:
+ np.ndarray: Generated world representation as numpy array
+ """
+ if self.offload_network:
+ self._load_network()
+
+ if self.offload_tokenizer:
+ self._load_tokenizer()
+
+ sample = self._run_model(prompt_embedding, negative_prompt_embedding)
+
+ if self.offload_network:
+ self._offload_network()
+
+ if self.offload_tokenizer:
+ self._load_tokenizer()
+
+ sample = self._run_tokenizer_decoding(sample)
+
+ if self.offload_tokenizer:
+ self._offload_tokenizer()
+ return sample
+
+ def generate(
+ self,
+ prompt: str,
+ negative_prompt: Optional[str] = None,
+ word_limit_to_skip_upsampler: Optional[int] = None,
+ ) -> tuple[np.ndarray, str] | None:
+ """Generate video from text prompt with optional negative prompt guidance.
+
+ Pipeline steps:
+ 1. Run safety checks on input prompt
+ 2. Enhance prompt using upsampler if enabled
+ 3. Run safety checks on upsampled prompt if applicable
+ 4. Convert prompt to embeddings
+ 5. Generate video frames using diffusion
+ 6. Run safety checks and apply face blur on generated video frames
+
+ Args:
+ prompt: Text description of desired video
+ negative_prompt: Optional text to guide what not to generate
+ word_limit_to_skip_upsampler: If set, skip the prompt upsampler whenever the prompt contains more than this many words
+
+ Returns:
+ tuple: (
+ Generated video frames as uint8 np.ndarray [T, H, W, C],
+ Final prompt used for generation (may be enhanced)
+ ), or None if content fails guardrail safety checks
+ """
+ log.info(f"Run with prompt: {prompt}")
+ log.info(f"Run with negative prompt: {negative_prompt}")
+ log.info(f"Run with prompt upsampler: {self.enable_prompt_upsampler}")
+
+ if self.enable_text_guardrail:
+ log.info("Run guardrail on prompt")
+ is_safe = self._run_guardrail_on_prompt_with_offload(prompt)
+ if not is_safe:
+ log.critical("Input text prompt is not safe")
+ return None
+ log.info("Pass guardrail on prompt")
+
+ # Enhance prompt
+ if self.enable_prompt_upsampler:
+ word_count = len(prompt.split())
+ if word_limit_to_skip_upsampler is None or word_count <= word_limit_to_skip_upsampler:
+ log.info("Run prompt upsampler on prompt")
+ prompt = self._run_prompt_upsampler_on_prompt_with_offload(prompt)
+ if self.enable_text_guardrail:
+ log.info("Run guardrail on upsampled prompt")
+ is_safe = self._run_guardrail_on_prompt_with_offload(prompt=prompt)
+ if not is_safe:
+ log.critical("Upsampled text prompt is not safe")
+ return None
+ log.info("Pass guardrail on upsampled prompt")
+ else:
+ log.info(
+ f"Skip prompt upsampler for better robustness because the number of words ({word_count}) in the prompt is greater than {word_limit_to_skip_upsampler}"
+ )
+
+ log.info("Run text embedding on prompt")
+ if negative_prompt:
+ prompts = [prompt, negative_prompt]
+ else:
+ prompts = [prompt]
+ prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(prompts)
+ prompt_embedding = prompt_embeddings[0]
+ negative_prompt_embedding = prompt_embeddings[1] if negative_prompt else None
+ log.info("Finish text embedding on prompt")
+
+ # Generate video
+ log.info("Run generation")
+ video = self._run_model_with_offload(
+ prompt_embedding,
+ negative_prompt_embedding=negative_prompt_embedding,
+ )
+ log.info("Finish generation")
+
+ if self.enable_video_guardrail:
+ log.info("Run guardrail on generated video")
+ video = self._run_guardrail_on_video_with_offload(video)
+ if video is None:
+
log.critical("Generated video is not safe") + return None + log.info("Pass guardrail on generated video") + + return video, prompt + + +class DiffusionVideo2WorldGenerationPipeline(DiffusionText2WorldGenerationPipeline): + def __init__( + self, + inference_type: str, + checkpoint_dir: str, + checkpoint_name: str, + prompt_upsampler_dir: Optional[str] = None, + enable_prompt_upsampler: bool = True, + enable_text_guardrail: bool = True, + enable_video_guardrail: bool = True, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + offload_prompt_upsampler: bool = False, + offload_guardrail_models: bool = False, + guidance: float = 7.0, + num_steps: int = 35, + height: int = 704, + width: int = 1280, + fps: int = 24, + num_video_frames: int = 121, + seed: int = 0, + num_input_frames: int = 1, + ): + """Initialize diffusion world generation pipeline. + + Args: + inference_type: Type of world generation ('text2world' or 'video2world') + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the diffusion transformer checkpoint to use + prompt_upsampler_dir: Directory containing prompt upsampler model weights + enable_prompt_upsampler: Whether to use prompt upsampling + enable_text_guardrail: Whether to enable text guardrail + enable_video_guardrail: Whether to enable video guardrail + offload_network: Whether to offload diffusion transformer after inference + offload_tokenizer: Whether to offload tokenizer after inference + offload_text_encoder_model: Whether to offload T5 model after inference + offload_prompt_upsampler: Whether to offload prompt upsampler + offload_guardrail_models: Whether to offload guardrail models + guidance: Classifier-free guidance scale + num_steps: Number of diffusion sampling steps + height: Height of output video + width: Width of output video + fps: Frames per second of output video + num_video_frames: Number of frames to generate + seed: Random seed for sampling + num_input_frames: Number of latent conditions + """ + self.num_input_frames = num_input_frames + super().__init__( + inference_type=inference_type, + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + prompt_upsampler_dir=prompt_upsampler_dir, + enable_prompt_upsampler=enable_prompt_upsampler, + enable_text_guardrail=enable_text_guardrail, + enable_video_guardrail=enable_video_guardrail, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + offload_text_encoder_model=offload_text_encoder_model, + offload_prompt_upsampler=offload_prompt_upsampler, + offload_guardrail_models=offload_guardrail_models, + guidance=guidance, + num_steps=num_steps, + height=height, + width=width, + fps=fps, + num_video_frames=num_video_frames, + seed=seed, + ) + + def _run_prompt_upsampler_on_prompt(self, image_or_video_path: str) -> str: + """Enhance the input prompt using visual context from the conditioning image. 
+ + Args: + image_or_video_path: Path to conditioning image or video used for visual context + + Returns: + str: Enhanced prompt incorporating visual details from the image + """ + dialog = prepare_dialog(image_or_video_path) + upsampled_prompt = run_chat_completion_vlm( + self.prompt_upsampler, dialog, max_gen_len=400, temperature=0.01, top_p=0.9, logprobs=False + ) + log.info(f"Upsampled prompt: {upsampled_prompt}") + return upsampled_prompt + + def _load_prompt_upsampler_model(self): + self.prompt_upsampler = create_vlm_prompt_upsampler( + checkpoint_dir=os.path.join(self.checkpoint_dir, self.prompt_upsampler_dir), + ) + + def _load_model(self): + self.model = load_model_by_config( + config_job_name=self.model_name, + config_file="df_config_config.py", + model_class=DiffusionV2WModel, + ) + + def _run_model( + self, + embedding: torch.Tensor, + condition_latent: torch.Tensor, + negative_prompt_embedding: torch.Tensor | None = None, + ) -> torch.Tensor: + """Generate video frames using the diffusion model. + + Args: + embedding: Text embedding tensor from T5 encoder + condition_latent: Latent tensor from conditioning image or video + negative_prompt_embedding: Optional embedding for negative prompt guidance + + Returns: + Tensor of generated video frames + + Note: + Model and tokenizer are automatically offloaded after inference + if offloading is enabled. + """ + # Get video batch and state shape + data_batch, state_shape = get_video_batch( + model=self.model, + prompt_embedding=embedding, + negative_prompt_embedding=negative_prompt_embedding, + height=self.height, + width=self.width, + fps=self.fps, + num_video_frames=self.num_video_frames, + ) + + # Generate video frames + video = generate_world_from_video( + model=self.model, + state_shape=self.model.state_shape, + is_negative_prompt=True, + data_batch=data_batch, + guidance=self.guidance, + num_steps=self.num_steps, + seed=self.seed, + condition_latent=condition_latent, + num_input_frames=self.num_input_frames, + ) + + return video + + def _run_tokenizer_encoding(self, image_or_video_path: str) -> torch.Tensor: + """ + Encode image to latent space + + Args: + image_or_video_path: Path to conditioning image + + Returns: + torch.Tensor: Latent tensor from tokenizer encoding + """ + condition_latent = get_condition_latent( + model=self.model, + input_image_or_video_path=image_or_video_path, + num_input_frames=self.num_input_frames, + state_shape=self.model.state_shape, + ) + + return condition_latent + + def _run_model_with_offload( + self, + prompt_embedding: torch.Tensor, + image_or_video_path: str, + negative_prompt_embedding: Optional[torch.Tensor] = None, + ) -> np.ndarray: + """Generate world representation with automatic model offloading. + + Wraps the core generation process with model loading/offloading logic + to minimize GPU memory usage during inference. 
+ + Args: + prompt_embedding: Text embedding tensor from T5 encoder + image_or_video_path: Path to conditioning image or video + negative_prompt_embedding: Optional embedding for negative prompt guidance + + Returns: + np.ndarray: Generated world representation as numpy array + """ + if self.offload_tokenizer: + self._load_tokenizer() + + condition_latent = self._run_tokenizer_encoding(image_or_video_path) + + if self.offload_network: + self._load_network() + + sample = self._run_model(prompt_embedding, condition_latent, negative_prompt_embedding) + + if self.offload_network: + self._offload_network() + + sample = self._run_tokenizer_decoding(sample) + + if self.offload_tokenizer: + self._offload_tokenizer() + + return sample + + def generate( + self, + prompt: str, + image_or_video_path: str, + negative_prompt: Optional[str] = None, + ) -> tuple[np.ndarray, str] | None: + """Generate video from text prompt and optional image. + + Pipeline steps: + 1. Run safety checks on input prompt + 2. Enhance prompt using upsampler if enabled + 3. Run safety checks on upsampled prompt if applicable + 4. Convert prompt to embeddings + 5. Generate video frames using diffusion + 6. Run safety checks and apply face blur on generated video frames + + Args: + prompt: Text description of desired video + image_or_video_path: Path to conditioning image or video + negative_prompt: Optional text to guide what not to generate + + Returns: + tuple: ( + Generated video frames as uint8 np.ndarray [T, H, W, C], + Final prompt used for generation (may be enhanced) + ), or None if content fails guardrail safety checks + """ + log.info(f"Run with prompt: {prompt}") + log.info(f"Run with image or video path: {image_or_video_path}") + log.info(f"Run with negative prompt: {negative_prompt}") + log.info(f"Run with prompt upsampler: {self.enable_prompt_upsampler}") + + if self.enable_text_guardrail and not self.enable_prompt_upsampler: + log.info("Run guardrail on prompt") + is_safe = self._run_guardrail_on_prompt_with_offload(prompt) + if not is_safe: + log.critical("Input text prompt is not safe") + return None + log.info("Pass guardrail on prompt") + + # Enhance prompt + if self.enable_prompt_upsampler: + log.info("Run prompt upsampler on image or video, input prompt is not used") + prompt = self._run_prompt_upsampler_on_prompt_with_offload(image_or_video_path=image_or_video_path) + if self.enable_text_guardrail: + log.info("Run guardrail on upsampled prompt") + is_safe = self._run_guardrail_on_prompt_with_offload(prompt) + if not is_safe: + log.critical("Upsampled text prompt is not safe") + return None + log.info("Pass guardrail on upsampled prompt") + + log.info("Run text embedding on prompt") + if negative_prompt: + prompts = [prompt, negative_prompt] + else: + prompts = [prompt] + prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(prompts) + prompt_embedding = prompt_embeddings[0] + negative_prompt_embedding = prompt_embeddings[1] if negative_prompt else None + log.info("Finish text embedding on prompt") + + # Generate video + log.info("Run generation") + video = self._run_model_with_offload( + prompt_embedding, + negative_prompt_embedding=negative_prompt_embedding, + image_or_video_path=image_or_video_path, + ) + log.info("Finish generation") + + if self.enable_video_guardrail: + log.info("Run guardrail on generated video") + video = self._run_guardrail_on_video_with_offload(video) + if video is None: + log.critical("Generated video is not safe") + return None + log.info("Pass guardrail on 
generated video") + + return video, prompt