diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..483b73b14d1fa97c3bc415607891042fc13e4fa6 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +fonts/arial.ttf filter=lfs diff=lfs merge=lfs -text +*.gif filter=lfs diff=lfs merge=lfs -text diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..8c04b274d3383afcc166392c28708271b99fea75 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,10 @@ +# Microsoft Open Source Code of Conduct + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). + +Resources: + +- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) +- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns +- Employees can reach out at [aka.ms/opensource/moderation-support](https://aka.ms/opensource/moderation-support) \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..c97e7ed36f67cd1b9d3d8cf465e0ec35661c38f1 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,14 @@ +# Contributing + +This project welcomes contributions and suggestions. Most contributions require you to +agree to a Contributor License Agreement (CLA) declaring that you have the right to, +and actually do, grant us the rights to use your contribution. For details, visit +https://cla.microsoft.com. + +When you submit a pull request, a CLA-bot will automatically determine whether you need +to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the +instructions provided by the bot. You will only need to do this once across all repositories using our CLA. + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). +For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. \ No newline at end of file diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..e7a447c1a46121b740823965a7e54e5f7b3e8504 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,96 @@ +# MICROSOFT RESEARCH LICENSE TERMS + +**IF YOU LIVE IN THE UNITED STATES, PLEASE READ THE “BINDING ARBITRATION AND CLASS ACTION WAIVER” SECTION BELOW. IT AFFECTS HOW DISPUTES ARE RESOLVED.** + +These license terms are an agreement between you and Microsoft Corporation (or one of its affiliates). They apply to the source code, object code, machine learning models, or data (collectively “Materials”) that accompany this license. IF YOU COMPLY WITH THESE LICENSE TERMS, YOU HAVE THE RIGHTS BELOW. BY USING THE MATERIALS, YOU ACCEPT THESE TERMS. + +## 1) INSTALLATION AND USE RIGHTS TO THE MATERIALS. + +Subject to the terms of this agreement, you have the below rights, if applicable, to use the Materials solely for non-commercial, non-revenue generating, research purposes: + +a) **Source Code.** If source code is included, you may use and modify the source code, but you may not distribute the source code. + +b) **Object Code.** If object code is included, you may use the object code, but you may not distribute the object code. + +c) **Models.** If machine learning model(s) are included, you may use the model(s), but you may not distribute the models. + +d) **Data.** If data is included, you may use the data, but your use must be consistent with the consent under which the data was provided and/or gathered and you may not modify or distribute the data. + +## 2) SCOPE OF LICENSE. + +The Materials are licensed, not sold. Microsoft reserves all other rights. Unless applicable law gives you more rights despite this limitation, you will not (and have no right to): + +a) Work around any technical limitations in the Materials that only allow you to use it in certain ways; + +b) Reverse engineer, decompile or disassemble the Materials; + +c) Remove, minimize, block, or modify any notices of Microsoft or its suppliers in the Materials; + +d) Use the Materials in any way that is against the law or to create or propagate malware; or + +e) Share, publish, distribute or lend the Materials, provide the Materials as a stand-alone hosted solution for others to use, or transfer the Materials or this agreement to any third party. + +## 3) PERSONAL DATA. + +If the data (set forth in Section 1(d) above) includes or is found to include any data that enables any ability to identify an individual ("Personal Data"), you will not use such Personal Data for any purpose other than was authorized and consented to by the data subject/research participant. You will not use Personal Data to contact any person. You will keep Personal Data in strict confidence. You will not share any Personal Data that is collected or in your possession with any third party for any reason and as required under the original consent agreement. Further, you will destroy the Personal Data and any backup or copies, **immediately upon the completion of your research.** + +## 4) LICENSE TO MICROSOFT. + +Notwithstanding the limitations in Section 1, you may distribute your modifications back to Microsoft, and if you do provide Microsoft with modifications of the Materials, you hereby grant Microsoft, without any restrictions or limitations, a non-exclusive, perpetual, irrevocable, royalty-free, assignable and sub-licensable license, to reproduce, publicly perform or display, install, use, modify, post, distribute, make and have made, sell and transfer such modifications and derivatives for any purpose. + +## 5) PUBLICATION. + +You may publish (or present papers or articles) on your results from using the Materials provided that no material or substantial portion of the Materials is included in any such publication or presentation. + +## 6) FEEDBACK. + +Any feedback about the Materials provided by you to us is voluntarily given, and Microsoft shall be free to use the feedback as it sees fit without obligation or restriction of any kind, even if the feedback is designated by you as confidential. **Additional** Such feedback shall be considered a contribution and licensed to Microsoft under the terms of Section 4 above. + +## 7) COMPLIANCE WITH TRADE LAWS. + +You acknowledge that the Materials may be subject to applicable trade laws in one or more countries. You will comply with all relevant laws and regulations applicable to the import or export of the Materials, including but not limited to, trade laws such as the U.S. Export Administration Regulations or other end-user, end use, and destination restrictions by the U.S. and other governments, as well as sanctions regulations administered by the U.S. Office of Foreign Assets Control. Microsoft may suspend or terminate the agreement immediately to the extent that Microsoft reasonably concludes that continued performance would violate trade laws or put it at risk of becoming subject to sanctions or penalties under trade laws. For additional information, see www.microsoft.com/exporting. + +## 8) SUPPORT SERVICES. + +Microsoft is not obligated under this agreement to provide any support services for the Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind. + +## 9) BINDING ARBITRATION AND CLASS ACTION WAIVER. + +**This Section applies if you live in (or, if a business, your principal place of business is in) the United States.** If you and Microsoft have a dispute, you and Microsoft agree to try for 60 days to resolve it informally. If you and Microsoft can’t, you and Microsoft agree to **binding individual arbitration before the American Arbitration Association** under the Federal Arbitration Act ("FAA"), and not to **sue in court in front of a judge or jury.** Instead, a neutral arbitrator will decide. **Class action lawsuits, class-wide arbitrations, private attorney-general actions,** and any other proceeding where someone acts in a representative capacity **are not allowed;** nor is combining individual proceedings without the consent of all parties. The complete Arbitration Agreement contains more terms and is at aka.ms/arb-agreement-1. You and Microsoft agree to these terms. + +## 10) ENTIRE AGREEMENT. + +This agreement, and any other terms Microsoft may provide for supplements, updates, or third-party applications, is the entire agreement for the Materials. + +## 11) APPLICABLE LAW AND PLACE TO RESOLVE DISPUTES. + +If you acquired the Materials in the United States or Canada, the laws of the state or province where you live (or, if a business, where your principal place of business is located) govern the interpretation of this agreement, claims for its breach, and all other claims (including consumer protection, unfair competition, and tort claims), regardless of conflict of laws principles, except that the FAA governs everything related to arbitration. If you acquired the Materials in any other country, its laws apply, except that the FAA governs everything related to arbitration. If U.S. federal jurisdiction exists, you and Microsoft consent to exclusive jurisdiction and venue in the federal court in King County, Washington for all disputes heard in court (excluding arbitration). If not, you and Microsoft consent to exclusive jurisdiction and venue in the Superior Court of King County, Washington for all disputes heard in court (excluding arbitration). + +## 12) CONSUMER RIGHTS; REGIONAL VARIATIONS. + +This agreement describes certain legal rights. You may have other rights, including consumer rights, under the laws of your state, province, or country. Separate and apart from your relationship with Microsoft, you may also have rights with respect to the party from which you acquired the Materials. This agreement does not change those other rights if the laws of your state, province, or country do not permit it to do so. For example, if you acquired the Materials in one of the below regions, or mandatory country law applies, then the following provisions apply to you: + +a) **Australia.** You have statutory guarantees under the Australian Consumer Law and nothing in this agreement is intended to affect those rights. + +b) **Canada.** If you acquired this software in Canada, you may stop receiving updates by turning off the automatic update feature, disconnecting your device from the Internet (if and when you re-connect to the Internet, however, the Materials will resume checking for and installing updates), or uninstalling the Materials. The product documentation, if any, may also specify how to turn off updates for your specific device or software. + +c) **Germany and Austria.** + i. **Warranty.** The properly licensed software will perform substantially as described in any Microsoft materials that accompany the Materials. However, Microsoft gives no contractual guarantee in relation to the licensed software. + ii. **Limitation of Liability.** In case of intentional conduct, gross negligence, claims based on the Product Liability Act, as well as, in case of death or personal or physical injury, Microsoft is liable according to the statutory law. + +Subject to the foregoing clause (ii), Microsoft will only be liable for slight negligence if Microsoft is in breach of such material contractual obligations, the fulfillment of which facilitate the due performance of this agreement, the breach of which would endanger the purpose of this agreement and the compliance with which a party may constantly trust in (so-called "cardinal obligations"). In other cases of slight negligence, Microsoft will not be liable for slight negligence. + +## 13) DISCLAIMER OF WARRANTY. + +THE MATERIALS ARE LICENSED "AS IS." YOU BEAR THE RISK OF USING THEM. MICROSOFT GIVES NO EXPRESS WARRANTIES, GUARANTEES, OR CONDITIONS. TO THE EXTENT PERMITTED UNDER APPLICABLE LAWS, MICROSOFT EXCLUDES ALL IMPLIED WARRANTIES, INCLUDING MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. + +## 14) LIMITATION ON AND EXCLUSION OF DAMAGES. + +IF YOU HAVE ANY BASIS FOR RECOVERING DAMAGES DESPITE THE PRECEDING DISCLAIMER OF WARRANTY, YOU CAN RECOVER FROM MICROSOFT AND ITS SUPPLIERS ONLY DIRECT DAMAGES UP TO U.S. $5.00. YOU CANNOT RECOVER ANY OTHER DAMAGES, INCLUDING CONSEQUENTIAL, LOST PROFITS, SPECIAL, INDIRECT OR INCIDENTAL DAMAGES. + +This limitation applies to: +- (a) anything related to the Materials, services, content (including code) on third party Internet sites, or third party applications; and +- (b) claims for breach of contract, warranty, guarantee, or condition; strict liability, negligence, or other tort; or any other claim; in each case to the extent permitted by applicable law. + +It also applies even if Microsoft knew or should have known about the possibility of the damages. The above limitation or exclusion may not apply to you because your state, province, or country may not allow the exclusion or limitation of incidental, consequential, or other damages. + diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000000000000000000000000000000000000..7d08146a631114a9a1628fa9be5e1a76f9af863d --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,37 @@ +## Security + +Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). + +If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. + +## Reporting Security Issues + +**Please do not report security vulnerabilities through public GitHub issues.** + +Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). + +If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). + +You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). + +Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: + + * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) + * Full paths of source file(s) related to the manifestation of the issue + * The location of the affected source code (tag/branch/commit or direct URL) + * Any special configuration required to reproduce the issue + * Step-by-step instructions to reproduce the issue + * Proof-of-concept or exploit code (if possible) + * Impact of the issue, including how an attacker might exploit the issue + +This information will help us triage your report more quickly. + +If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. + +## Preferred Languages + +We prefer all communications to be in English. + +## Policy + +Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). \ No newline at end of file diff --git a/assets/Demonstrator/Fig_01.png b/assets/Demonstrator/Fig_01.png new file mode 100644 index 0000000000000000000000000000000000000000..cf98b3204693dfc5846304d2e8e15cc7d161c46a Binary files /dev/null and b/assets/Demonstrator/Fig_01.png differ diff --git a/assets/Demonstrator/Fig_02.png b/assets/Demonstrator/Fig_02.png new file mode 100644 index 0000000000000000000000000000000000000000..249503603c9b7651bd92dca8f7e9eb2e3b7f8615 Binary files /dev/null and b/assets/Demonstrator/Fig_02.png differ diff --git a/assets/Demonstrator/Fig_03.png b/assets/Demonstrator/Fig_03.png new file mode 100644 index 0000000000000000000000000000000000000000..6530e5c0d9a879ab85e8267b977c3037a2b13939 Binary files /dev/null and b/assets/Demonstrator/Fig_03.png differ diff --git a/assets/Demonstrator/Fig_04.png b/assets/Demonstrator/Fig_04.png new file mode 100644 index 0000000000000000000000000000000000000000..2db5d14c77c47a0aa5b87da50bee9cc314432a45 Binary files /dev/null and b/assets/Demonstrator/Fig_04.png differ diff --git a/assets/Demonstrator/Fig_05.png b/assets/Demonstrator/Fig_05.png new file mode 100644 index 0000000000000000000000000000000000000000..72b6c3ed269770fdd157a40710acdc3c59a20617 Binary files /dev/null and b/assets/Demonstrator/Fig_05.png differ diff --git a/assets/Demonstrator/Fig_06.png b/assets/Demonstrator/Fig_06.png new file mode 100644 index 0000000000000000000000000000000000000000..2202e663088573436f3eefebc03a9c38649d0d8f Binary files /dev/null and b/assets/Demonstrator/Fig_06.png differ diff --git a/assets/Demonstrator/Fig_07.png b/assets/Demonstrator/Fig_07.png new file mode 100644 index 0000000000000000000000000000000000000000..0eadcc8639360772a2fd24c3bfe8b2d54124f0f6 Binary files /dev/null and b/assets/Demonstrator/Fig_07.png differ diff --git a/assets/Demonstrator/Fig_08.png b/assets/Demonstrator/Fig_08.png new file mode 100644 index 0000000000000000000000000000000000000000..bbfaeba2e1cd332705a2ce9120802228d661d202 Binary files /dev/null and b/assets/Demonstrator/Fig_08.png differ diff --git a/assets/Demonstrator/Fig_09.png b/assets/Demonstrator/Fig_09.png new file mode 100644 index 0000000000000000000000000000000000000000..d186776ee5d296483aa5d69bbc77b708ff489da5 Binary files /dev/null and b/assets/Demonstrator/Fig_09.png differ diff --git a/assets/Demonstrator/Fig_10.png b/assets/Demonstrator/Fig_10.png new file mode 100644 index 0000000000000000000000000000000000000000..bf8b95d3a184d6e5217ba0f8f7654d990390ba9f Binary files /dev/null and b/assets/Demonstrator/Fig_10.png differ diff --git a/assets/Demonstrator/Fig_11.png b/assets/Demonstrator/Fig_11.png new file mode 100644 index 0000000000000000000000000000000000000000..2392be3aa9dd92b062ec21f7ee40aa4be0fefaa9 Binary files /dev/null and b/assets/Demonstrator/Fig_11.png differ diff --git a/assets/Demonstrator/Fig_12.png b/assets/Demonstrator/Fig_12.png new file mode 100644 index 0000000000000000000000000000000000000000..1699bb254882033e9e4a92379c27882cbc2a61bb Binary files /dev/null and b/assets/Demonstrator/Fig_12.png differ diff --git a/assets/Demonstrator/Fig_13.png b/assets/Demonstrator/Fig_13.png new file mode 100644 index 0000000000000000000000000000000000000000..bfccc7628ad75f236f3b6dd6218c7556c733842a Binary files /dev/null and b/assets/Demonstrator/Fig_13.png differ diff --git a/assets/Demonstrator/Fig_14.png b/assets/Demonstrator/Fig_14.png new file mode 100644 index 0000000000000000000000000000000000000000..c57d92e5a40201239694db2d527367c9cb950edd Binary files /dev/null and b/assets/Demonstrator/Fig_14.png differ diff --git a/assets/Demonstrator/Fig_15.png b/assets/Demonstrator/Fig_15.png new file mode 100644 index 0000000000000000000000000000000000000000..01c1830ff09e3ce17856d217f0e4f08a6805baaa Binary files /dev/null and b/assets/Demonstrator/Fig_15.png differ diff --git a/assets/Readme/model_capabilities.gif b/assets/Readme/model_capabilities.gif new file mode 100644 index 0000000000000000000000000000000000000000..657c6eadb37a800e06da59fff349197901a24a02 --- /dev/null +++ b/assets/Readme/model_capabilities.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:87cf1460b2779a1c85b70e2229a7e1e256c501a5e3db26ea74e445b9dc75e965 +size 8629032 diff --git a/assets/Readme/wham_gen_1.gif b/assets/Readme/wham_gen_1.gif new file mode 100644 index 0000000000000000000000000000000000000000..c2e592d8c9c4bf5c8dbcb17f4cfc7ae264dd2b1e --- /dev/null +++ b/assets/Readme/wham_gen_1.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:96558d0ad8084eafaf60ee360f13fe8decfbc5ac737b0c2788c01310e81750d1 +size 4415495 diff --git a/assets/Readme/wham_gen_2.gif b/assets/Readme/wham_gen_2.gif new file mode 100644 index 0000000000000000000000000000000000000000..eb503adb078ab7027d2bf86b994238de86a9cc99 --- /dev/null +++ b/assets/Readme/wham_gen_2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1296bb4ccdac5c7d3a1e7e9adfc48a6ec255933ff252a31d4e45cd117a28aee7 +size 4150372 diff --git a/assets/Readme/wham_gen_3.gif b/assets/Readme/wham_gen_3.gif new file mode 100644 index 0000000000000000000000000000000000000000..cc562cfbdbcef859e4ca5fcb16d14daa0d5bceeb --- /dev/null +++ b/assets/Readme/wham_gen_3.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cb8ea8b3d6c8ec737a9b03f4cd93aeb36ddddc33695849b9b83543a8c2242b6f +size 4267917 diff --git a/assets/Readme/wham_gen_4.gif b/assets/Readme/wham_gen_4.gif new file mode 100644 index 0000000000000000000000000000000000000000..17b8d01d3d103c52b3a05abf9362079ba8f7e0fa --- /dev/null +++ b/assets/Readme/wham_gen_4.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:45e895599dddae5e6d2eb31f66957726fb82662f41b149f4de206466083f5a42 +size 4296210 diff --git a/assets/Readme/wham_gen_5.gif b/assets/Readme/wham_gen_5.gif new file mode 100644 index 0000000000000000000000000000000000000000..ca33aad575050da816c7e822de969c0b3947e9fa --- /dev/null +++ b/assets/Readme/wham_gen_5.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e7e7675c737bf5cbdfb54dfcc568eeda4c4212dbe5726741205610ab29cfcabb +size 4243973 diff --git a/assets/Readme/wham_gen_6.gif b/assets/Readme/wham_gen_6.gif new file mode 100644 index 0000000000000000000000000000000000000000..8d8702575ab2afae0fd3f48e14d42cb5884449c7 --- /dev/null +++ b/assets/Readme/wham_gen_6.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e536b1f88a92de4e116a6acd022987778f63ed5a841517758c14a0d7f2a3c2bd +size 4085542 diff --git a/assets/Readme/wham_gen_7.gif b/assets/Readme/wham_gen_7.gif new file mode 100644 index 0000000000000000000000000000000000000000..d5b37ed5fcfee7a9ee74a589fc0f8d587fe9354d --- /dev/null +++ b/assets/Readme/wham_gen_7.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb7e6c63eb8c46fc8c824d93406550082b6532ea9473cd021bae72a7d6cbe7db +size 4129684 diff --git a/assets/Readme/wham_gen_8.gif b/assets/Readme/wham_gen_8.gif new file mode 100644 index 0000000000000000000000000000000000000000..46bbc7479eccce12d33495fd0dc60a0edb2f7176 --- /dev/null +++ b/assets/Readme/wham_gen_8.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:366f3f92310f3cfa55c9f4da719b01c8399c42f7d7bb860c5f7153568e4991d5 +size 3980988 diff --git a/assets/Readme/wham_gen_9.gif b/assets/Readme/wham_gen_9.gif new file mode 100644 index 0000000000000000000000000000000000000000..56270ee8f944d14ad3acf8a07689b8b165deef57 --- /dev/null +++ b/assets/Readme/wham_gen_9.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:931713a1d9a9dbdef7b4a1821ef78d490282bf8475e65b39948f8b5f42dc9982 +size 4526755 diff --git a/configs/metadata_custom_tag.config b/configs/metadata_custom_tag.config new file mode 100644 index 0000000000000000000000000000000000000000..7e25921318e5a752970f896da15c09e3b018c713 --- /dev/null +++ b/configs/metadata_custom_tag.config @@ -0,0 +1,5 @@ +%Image::ExifTool::UserDefined = ( + 'Image::ExifTool::XMP::xmp' => { + 'ProgramName' => { Name => 'ProgramName', Writable => 'string' } + } +); \ No newline at end of file diff --git a/models/WHAM_1.6B_v1.ckpt b/models/WHAM_1.6B_v1.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..ffe3e3f689be13555425b4858d4a00f57d8f08f6 --- /dev/null +++ b/models/WHAM_1.6B_v1.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9c4997074883aa1a39a5994a7dea91fb62b2382fc039523458827adb777af8e9 +size 20339650059 diff --git a/models/WHAM_200M.ckpt b/models/WHAM_200M.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..140533fc3cfd1afa848dcfabb6afeee770227659 --- /dev/null +++ b/models/WHAM_200M.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5ddb8e03a33f0849a63da030fea3de4994d95e16888993b8ab92faa904f3b31f +size 3980245067 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..46d9e08fd5c5b3890d72e253f4d4db47233c7d80 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,48 @@ +--find-links https://download.pytorch.org/whl/torch_stable.html +aiohttp==3.9.3 +aiosignal==1.3.1 +async-timeout==4.0.3 +attrs==23.2.0 +blinker==1.7.0 +certifi==2024.2.2 +charset-normalizer==3.3.2 +click==8.1.7 +cloudpickle==3.0.0 +cmake==3.28.3 +einops==0.6.0 +ffmpegcv==0.3.10 +filelock==3.13.1 +Flask==3.0.2 +frozenlist==1.4.1 +fsspec==2024.2.0 +idna==3.6 +importlib_metadata==7.0.2 +itsdangerous==2.1.2 +Jinja2==3.1.3 +lightning-utilities==0.10.1 +lit==17.0.6 +MarkupSafe==2.1.5 +mpmath==1.3.0 +multidict==6.0.5 +networkx==3.2.1 +numpy==1.25.2 +opencv-python==4.6.0.66 +opencv-python-headless==4.9.0.80 +packaging==23.2 +pillow==10.2.0 +pytorch-lightning==1.9.4 +PyYAML==6.0.1 +requests==2.31.0 +sympy==1.12 +tensordict==0.1.2 +torch==2.0.1+cu118 +torchinfo==1.7.1 +torchmetrics==0.11.4 +torchvision==0.15.2+cu118 +tqdm==4.66.2 +triton==2.0.0 +typing_extensions==4.10.0 +urllib3==2.2.1 +Werkzeug==3.0.1 +yarl==1.9.4 +zipp==3.17.0 diff --git a/run_dreaming.py b/run_dreaming.py new file mode 100644 index 0000000000000000000000000000000000000000..2e189376d727e01383be80e3c4690ef3f3a0d302 --- /dev/null +++ b/run_dreaming.py @@ -0,0 +1,264 @@ +""" +Example script for running dreaming on a dataset. +The idea is that there are ground_truth ("reference") video clips, and we dream the same clips given some initial context. + +After dreaming, we have two sets of videos which, barring the intrinsic noise of the game environment (e.g., randomness of other players), +should be identical if model was ideal. +""" + +import argparse +from pathlib import Path +import os +import subprocess + +import cv2 +from tensordict import TensorDict +import torch as th +from tqdm import tqdm +import numpy as np +import ffmpegcv +from PIL import Image + +import wham.utils as utils + + +parser = argparse.ArgumentParser(description="Run dreaming.") +parser.add_argument("--model_path", type=str, required=True, help="Path to the model checkpoint.") +parser.add_argument("--data_path", type=str, required=True, help="Path to the directory that contains the ground truth data to dream for.") +parser.add_argument("--output", type=str, default="dreaming_output", help="Path to the directory where output should be put.") +parser.add_argument("--max_files", type=int, default=None, help="Maximum number of files to process.") +parser.add_argument("--metadata_config", type=str, default="configs/metadata_custom_tag.config", help="Path to metadata tag config for origin field.") + + +parser.add_argument( + "--protocol", + type=str, + default="base", + choices=["base", "comprehensive"], + help="What protocol to use for the dreaming. base = action conditioned, comprehensive = dream actions as well.", +) +parser.add_argument("--batch_size", type=int, default=1, help="Batch size for dreaming. Higher batch_size uses more VRAM but overall is faster.") +parser.add_argument("--context_length", type=int, default=10, help="Number of frames to use an initial context.") +parser.add_argument("--steps_to_dream", type=int, default=10, help="Batch size for dreaming.") + +parser.add_argument("--sampling_temperature", type=float, default=0.9, help="Temperature for sampling from the model.") +parser.add_argument("--sampling_top_k", type=int, default=None, help="Top-k for sampling from the model.") +parser.add_argument("--sampling_top_p", type=float, default=None, help="Top-p for sampling from the model.") + + +def get_context_data(image_context, action_context, action_sequences): + # Make sure we have CHW images: + assert image_context.shape[-3] == 3, "Image context should be CHW" + + image_context = th.from_numpy(image_context).cuda() + action_data = th.from_numpy(action_context).float().cuda() + action_sequences = th.from_numpy(action_sequences).float().cuda() if action_sequences is not None else None + + return TensorDict({"images": image_context, "actions_output": action_data}, batch_size=image_context.shape[:2]) + + +def add_video_metadata(file_path, metadata_config): + # Construct the exiftool command + cmd = [ + 'exiftool', + '-config', metadata_config, + f'-ProgramName=\"{utils.PROGRAM_NAME}\"', + '-overwrite_original', + file_path + ] + + try: + # Execute the exiftool command + subprocess.run(cmd, check=True) + print(f"Metadata modified successfully.") + # Print the new file metadata + cmd_output = [ + 'exiftool', + file_path + ] + subprocess.run(cmd_output, check=True) + except subprocess.CalledProcessError as e: + print(f"Error modifying metadata: {e}") + + +@th.no_grad() +def do_dreaming(model, image_context, action_context, args, action_sequences=None): + """ + image_contect and action_context provide the initial context for the model to dream from. + + If action_sequences (batch_size, args.steps_to_dream, action_dim) is provided, then model will be prompted with these actions. + """ + context_data = get_context_data(image_context, action_context, action_sequences) + encoded_context_data = model.encode_context(context_data) + + encoded_action_sequences = None + if action_sequences is not None: + assert action_sequences.shape[1] == args.steps_to_dream, "action_sequences should have shape (batch_size, args.steps_to_dream, action_dim)" + action_sequences = TensorDict({"actions_output": action_sequences}, batch_size=action_sequences.shape[:2]).cuda() + encoded_action_sequences = model.encode_context(action_sequences) + + encoded_dreamt_steps = [] + + for dream_step in range(args.steps_to_dream): + encoded_predicted_step, _ = model.predictor.predict_next_step( + encoded_context_data, temperature=args.sampling_temperature, top_k=args.sampling_top_k, top_p=args.sampling_top_p, min_tokens_to_keep=1 + ) + + # Remove first step from context if we are at the max context length: + if encoded_context_data.shape[1] == args.context_length: + encoded_context_data = encoded_context_data[:, 1:] + + # Add predicted image + action to the context + append_step = encoded_predicted_step + if encoded_action_sequences is not None: + # Replace predicted action with real action + append_step["actions_output"] = encoded_action_sequences["actions_output"][:, [dream_step], :] + encoded_context_data = th.cat((encoded_context_data, append_step), dim=1) + + encoded_dreamt_steps.append(encoded_predicted_step) + + # Decode everything + dreamed_images = [] + actions_during_dream = [] + for seq_i in range(args.steps_to_dream): + decoded_step = model.decode_context(encoded_dreamt_steps[seq_i]) + dreamed_images.append(decoded_step["images"][:, [0]].cpu().numpy()) + actions_during_dream.append(decoded_step["actions_output"][:, [0]].cpu().numpy()) + + dreamed_images = np.concatenate(dreamed_images, axis=1) + actions_during_dream = np.concatenate(actions_during_dream, axis=1) + + return dreamed_images, actions_during_dream + + +@th.no_grad() +def encode_decode_images(model, images): + """ + Pass ground_truth images through the encoding/decoding process of the model. + """ + context = TensorDict({"images": th.from_numpy(images).cuda()}, batch_size=images.shape[:2]) + output_images = [] + for seq_i in range(images.shape[1]): + encoded_images = model.encode_context(context[:, [seq_i]]) + decoded_images = model.decode_context(encoded_images) + output_images.append(decoded_images["images"].cpu().numpy()) + return np.concatenate(output_images, axis=1) + + +def main(args): + total_video_length = args.context_length + args.steps_to_dream + + # Now, load the model: + model_path = Path(args.model_path) + assert model_path.is_file(), "Could not find the model!" + model = utils.load_model_from_checkpoint(model_path).cuda() + + # Glob the dataset to find all the ground truth segments we want to construct a dream for: + data_path = Path(args.data_path) + ground_truth_files = list(data_path.rglob("*.npz")) + num_dreams = len(ground_truth_files) + + if args.max_files is not None: + # Sort to make sure we always get the same files + ground_truth_files = sorted(ground_truth_files) + ground_truth_files = ground_truth_files[: args.max_files] + num_dreams = len(ground_truth_files) + + output_path = Path(args.output) + os.makedirs(output_path, exist_ok=True) + + print("=" * 100) + print(f"GENERATING DREAMS OF {num_dreams} SEGMENTS") + print(f"WRITING TO {args.output}") + print("=" * 100) + + dreams_created = 0 + with tqdm(total=num_dreams, desc="Dreams") as pbar: + while ground_truth_files: + # Load batch_size headers: + batches = min(args.batch_size, len(ground_truth_files)) + batched_image_context = [] + batched_image_sequence = [] + batched_action_context = [] + batched_action_sequence = [] + episode_names = [] + for i in range(batches): + episode = ground_truth_files.pop() + episode_names.append(episode) + try: + data = np.load(episode) + images = data["images"] + actions = data["actions"] + except Exception: + print(f"Failed to load episode {episode} - skipping.") + continue + + if actions.shape[0] < total_video_length: + # We want to make sure we have ground_truth comparisons for the entire dream, so we ensure the episode is long enough + raise ValueError(f"Episode {episode} is too short to dream from. It has {actions.shape[0]} steps, but we need at least {total_video_length}.") + batched_image_context.append(images[: args.context_length]) + batched_image_sequence.append(images[args.context_length: total_video_length]) + batched_action_context.append(actions[: args.context_length]) + batched_action_sequence.append(actions[args.context_length: total_video_length]) + + image_context = np.array(batched_image_context) + image_sequences = np.array(batched_image_sequence) + action_context = np.array(batched_action_context) + action_sequences = np.array(batched_action_sequence) + + if args.protocol == "comprehensive": + # We do not need to pass in the action sequences for comprehensive protocol + action_sequences = None + + full_image_sequence = np.concatenate((image_context, image_sequences), axis=1) + + dreamt_images, actions_during_dream = do_dreaming(model, image_context, action_context, args, action_sequences=action_sequences) + encoded_decoded_images_batch = encode_decode_images(model, full_image_sequence) + + pbar.update(batches) + dreams_created += batches + + # Save the dreams: + # We are aiming to mimic the folder structure of the ground truth dataset, so use the episode names + # but make them relative to our output folder: + for i, dream in enumerate(dreamt_images): + episode = episode_names[i] + output_file = output_path / episode.relative_to(data_path) + output_file.parent.mkdir(parents=True, exist_ok=True) + np.savez( + output_file, + context_length=args.context_length, + steps_to_dream=args.steps_to_dream, + raw_context=image_context[i], + dreamt_images=dream, + all_actions=np.concatenate((action_context[i], actions_during_dream[i])), + encoded_decoded_ground_truth_images=encoded_decoded_images_batch[i], + ) + + video_file = str(output_file.with_suffix(".mp4")) + writer = ffmpegcv.VideoWriter(video_file, None, utils.DREAMING_FPS) + full_sequence = np.concatenate((image_context[i], dream), axis=0) + for frame in full_sequence: + img = frame.transpose(1, 2, 0).astype(np.uint8).copy() + # Please DO NOT remove this watermark. This will infringe upon the repo's license agreement + (text_width, _), _ = cv2.getTextSize(utils.WATERMARK_TEXT, utils.WATERMARK_FONT, utils.WATERMARK_FONT_SCALE, utils.WATERMARK_FONT_THICKNESS) + x = img.shape[1] - text_width - 10 # 10 pixels from the right edge + y = img.shape[0] - 10 # 10 pixels from the bottom edge + cv2.putText(img, utils.WATERMARK_TEXT, (x, y), utils.WATERMARK_FONT, utils.WATERMARK_FONT_SCALE, utils.WATERMARK_FONT_COLOR, utils.WATERMARK_FONT_THICKNESS) + + # Add image metadata + pil_image = Image.fromarray(img) + pil_image.info['Id'] = 0x0131 + pil_image.info['Type'] = 2 + pil_image.info['Value'] = utils.PROGRAM_NAME.encode("utf-8") + pil_image.info['Len'] = len(utils.PROGRAM_NAME) + 1 + + # Convert pil_image to a CV2 format for the video writer + cv_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR) + writer.write(cv_image) + writer.release() + add_video_metadata(video_file, args.metadata_config) + +if __name__ == "__main__": + args = parser.parse_args() + main(args) diff --git a/run_server.py b/run_server.py new file mode 100644 index 0000000000000000000000000000000000000000..3e83bad2c1ddca434ed53353ba50879876fed388 --- /dev/null +++ b/run_server.py @@ -0,0 +1,519 @@ +import argparse +from dataclasses import dataclass, field +import json +import copy +import multiprocessing as mp +import uuid +from datetime import datetime, timedelta +from collections import defaultdict, deque +import io +import zipfile +import queue +import time +import random +import logging + +from tensordict import TensorDict +import cv2 +from flask import Flask, request, make_response, send_file +from PIL import Image +import torchvision.transforms as T +import numpy as np +import torch as th + +from wham.utils import load_model_from_checkpoint, POS_BINS_BOUNDARIES, POS_BINS_MIDDLE + +logging.basicConfig(level=logging.INFO) + +parser = argparse.ArgumentParser(description="Simple Dreamer") +parser.add_argument("--model", type=str, required=True, help="Path to the model file for the local runs") +parser.add_argument("--debug", action="store_true", help="Enable flask debug mode.") +parser.add_argument("--random_model", action="store_true", help="Use randomly initialized model instead of the provided one") +parser.add_argument("--port", type=int, default=5000) + +parser.add_argument("--max_concurrent_jobs", type=int, default=30, help="Maximum number of jobs that can be run concurrently on this server.") +parser.add_argument("--max_dream_steps_per_job", type=int, default=10, help="Maximum number of dream steps each job can request.") +parser.add_argument("--max_job_lifespan", type=int, default=60 * 10, help="Maximum number of seconds we keep run around if not polled.") + +parser.add_argument("--image_width", type=int, default=300, help="Width of the image") +parser.add_argument("--image_height", type=int, default=180, help="Height of the image") + +parser.add_argument("--max_batch_size", type=int, default=3, help="Maximum batch size for the dreamer workers") + +PREDICTION_JSON_FILENAME = "predictions.json" +# Minimum time between times we check when to delete jobs. We do this when adding new jobs. +JOB_CLEANUP_CHECK_RATE = timedelta(seconds=10) + +MAX_CANCELLED_ID_QUEUE_SIZE = 100 + +DEFAULT_SAMPLING_SETTINGS = { + "temperature": 0.9, + "top_k": None, + "top_p": 1.0, + "max_context_length": 10, +} + + +def float_or_none(string): + if string.lower() == "none": + return None + return float(string) + + +def be_image_preprocess(image, target_width, target_height): + # If target_width and target_height are specified, resize the image. + if target_width is not None and target_height is not None: + # Make sure we do not try to resize if the image is already the correct size. + if image.shape[1] != target_width or image.shape[0] != target_height: + image = cv2.resize(image, (target_width, target_height)) + return np.transpose(image, (2, 0, 1)) + + +def action_vector_to_be_action_vector(action): + # Preprocess a BE action vector from 16 numbers with: + # 12 buttons [0, 1] and 4 stick directions [-1, 1] + # to discrete actions valid for the token model + # 12 buttons [0, 1] and 4 stick directions {discrete bin} + action[-4:] = np.digitize(action[-4:], bins=POS_BINS_BOUNDARIES) - 1 + return action + + +def be_action_vector_to_action_vector(action): + # Preprocess a BE action vector into unified space + for stick_index in range(-4, 0): + action[stick_index] = POS_BINS_MIDDLE[int(action[stick_index])] + return action + + + +@dataclass +class DreamJob: + job_id: str + sampling_settings: dict + num_predictions_remaining: int + num_predictions_done: int + # (B, T, C, H, W) + context_images: th.Tensor + context_actions: th.Tensor + # Tokens that will replace the context_images if they are provided + context_tokens: list + # This will replace the dreamed action if provided. + # For every step, we remove the first action until exhausted + actions_to_take: th.Tensor = None + + +@dataclass +class DreamJobResult: + job_id: str + dream_step_index: int + # (B, 1, C, H, W) + dreamt_image: th.Tensor + dreamt_action: th.Tensor + dreamt_tokens: th.Tensor + result_creation_time: datetime = field(default_factory=datetime.now) + + + +def setup_and_load_model_be_model(args): + model = load_model_from_checkpoint(args.model) + th.set_float32_matmul_precision("high") + th.backends.cuda.matmul.allow_tf32 = True + return model + + +def get_job_batchable_information(job): + """Return comparable object of job information. Used for batching""" + context_length = job.context_images.shape[1] + return (context_length, job.sampling_settings) + + +def fetch_list_of_batchable_jobs(job_queue, cancelled_ids_set, max_batch_size, timeout=1): + """Return a list of jobs (or empty list) that can be batched together""" + batchable_jobs = [] + required_job_info = None + while len(batchable_jobs) < max_batch_size: + try: + job = job_queue.get(timeout=timeout) + except queue.Empty: + break + # If pipe breaks, also gracefully return + except OSError: + break + if job.job_id in cancelled_ids_set: + # This job was cancelled, so discard it completely + continue + job_info = get_job_batchable_information(job) + if required_job_info is None: + required_job_info = job_info + elif required_job_info != job_info: + # This job is not batchable, put it back + job_queue.put(job) + # we assume here that, generally, the others jobs would also be + # invalid. So we just return the batchable jobs we have instead + # of going through more. + break + batchable_jobs.append(job) + return batchable_jobs + + +def update_cancelled_jobs(cancelled_ids_queue, cancelled_ids_deque, cancelled_ids_set): + """IN-PLACE Update cancelled_ids_set with new ids from the queue""" + has_changed = False + while not cancelled_ids_queue.empty(): + try: + cancelled_id = cancelled_ids_queue.get_nowait() + except queue.Empty: + break + cancelled_ids_deque.append(cancelled_id) + has_changed = True + + if has_changed: + cancelled_ids_set.clear() + cancelled_ids_set.update(cancelled_ids_deque) + + +def predict_step(context_data, sampling_settings, model, tokens=None): + with th.no_grad(): + predicted_step = model.predict_next_step(context_data, min_tokens_to_keep=1, tokens=tokens, **sampling_settings) + return predicted_step + + +def dreamer_worker(job_queue, result_queue, cancelled_jobs_queue, quit_flag, device_to_use, args): + logger = logging.getLogger(f"dreamer_worker {device_to_use}") + logger.info("Loading up model...") + model = setup_and_load_model_be_model(args) + model = model.to(device_to_use) + logger.info("Model loaded. Fetching results") + + cancelled_ids_deque = deque(maxlen=MAX_CANCELLED_ID_QUEUE_SIZE) + cancelled_ids_set = set() + + while not quit_flag.is_set(): + update_cancelled_jobs(cancelled_jobs_queue, cancelled_ids_deque, cancelled_ids_set) + batchable_jobs = fetch_list_of_batchable_jobs(job_queue, cancelled_ids_set, max_batch_size=args.max_batch_size) + if len(batchable_jobs) == 0: + continue + sampling_settings = batchable_jobs[0].sampling_settings + # make better way for passing these arguments around. sampling_settings + # is passed as kwargs to predicting step, but max_context_length is not part of valid + # keys there, so we need to pop it out. + max_context_length = sampling_settings.pop("max_context_length") + + images = [job.context_images[:, :max_context_length] for job in batchable_jobs] + actions = [job.context_actions[:, :max_context_length] for job in batchable_jobs] + tokens = [job.context_tokens for job in batchable_jobs] + + images = th.concat(images, dim=0).to(device_to_use) + actions = th.concat(actions, dim=0).to(device_to_use) + + context_data = TensorDict({ + "images": images, + "actions_output": actions + }, batch_size=images.shape[:2]) + + predicted_step, predicted_image_tokens = predict_step(context_data, sampling_settings, model, tokens) + + predicted_step = predicted_step.cpu() + predicted_images = predicted_step["images"] + predicted_actions = predicted_step["actions_output"] + predicted_image_tokens = predicted_image_tokens.cpu() + + for job_i, job in enumerate(batchable_jobs): + image_context = job.context_images + action_context = job.context_actions + token_context = job.context_tokens + # Keep batch dimension + dreamt_image = predicted_images[job_i].unsqueeze(0) + dreamt_action = predicted_actions[job_i].unsqueeze(0) + dreamt_tokens = predicted_image_tokens[job_i].unsqueeze(0) + + # Replace the dreamed action if provided + actions_to_take = job.actions_to_take + if actions_to_take is not None and actions_to_take.shape[1] > 0: + dreamt_action = actions_to_take[:, 0:1] + # Remove the action we took + actions_to_take = actions_to_take[:, 1:] + if actions_to_take.shape[1] == 0: + actions_to_take = None + + result_queue.put(DreamJobResult( + job_id=job.job_id, + dream_step_index=job.num_predictions_done, + dreamt_image=dreamt_image, + dreamt_action=dreamt_action, + dreamt_tokens=dreamt_tokens + )) + + # Add job back in the queue if we have more steps to do + if job.num_predictions_remaining > 0: + # Stack the dreamt image and action to the context + if image_context.shape[1] >= max_context_length: + image_context = image_context[:, 1:] + action_context = action_context[:, 1:] + token_context = token_context[1:] + image_context = th.cat([image_context, dreamt_image], dim=1) + action_context = th.cat([action_context, dreamt_action], dim=1) + token_context.append(dreamt_tokens[0, 0].tolist()) + # We need to add context length back to sampling settings... + # add some better way of passing these settings around + job.sampling_settings["max_context_length"] = max_context_length + job_queue.put(DreamJob( + job_id=job.job_id, + sampling_settings=job.sampling_settings, + num_predictions_remaining=job.num_predictions_remaining - 1, + num_predictions_done=job.num_predictions_done + 1, + context_images=image_context, + context_actions=action_context, + context_tokens=token_context, + actions_to_take=actions_to_take + )) + + +class DreamerServer: + def __init__(self, num_workers, args): + self.num_workers = num_workers + self.args = args + self.model = None + self.jobs = mp.Queue(maxsize=args.max_concurrent_jobs) + self.results_queue = mp.Queue() + self.cancelled_jobs = set() + self.cancelled_jobs_queues = [mp.Queue() for _ in range(num_workers)] + # job_id -> results + self._last_result_cleanup = datetime.now() + self._max_job_lifespan_datetime = timedelta(seconds=args.max_job_lifespan) + self.local_results = defaultdict(list) + self.logger = logging.getLogger("DreamerServer") + + def get_details(self): + details = { + "model_file": self.args.model, + "max_concurrent_jobs": self.args.max_concurrent_jobs, + "max_dream_steps_per_job": self.args.max_dream_steps_per_job, + "max_job_lifespan": self.args.max_job_lifespan, + } + return json.dumps(details) + + def _check_if_should_remove_old_jobs(self): + time_now = datetime.now() + # Only cleanup every JOB_CLEANUP_CHECK_RATE seconds at most + if time_now - self._last_result_cleanup < JOB_CLEANUP_CHECK_RATE: + return + + self._last_result_cleanup = time_now + # First add existing results to the local results + self._gather_new_results() + # Check if we should remove old jobs + job_ids = list(self.local_results.keys()) + for job_id in job_ids: + results = self.local_results[job_id] + # If newest result is older than max_job_lifespan, remove the job + if time_now - results[-1].result_creation_time > self._max_job_lifespan_datetime: + self.logger.info(f"Deleted job {job_id} because it was too old. Last result was {results[-1].result_creation_time}") + del self.local_results[job_id] + + def add_new_job(self, request, request_json): + """ + Add new dreaming job to the queues. + Request should have: + + + Returns: json object with new job id + """ + self._check_if_should_remove_old_jobs() + + sampling_settings = copy.deepcopy(DEFAULT_SAMPLING_SETTINGS) + if "num_steps_to_predict" not in request_json: + return make_response("num_steps_to_predict not in request", 400) + num_steps_to_predict = request_json['num_steps_to_predict'] + if num_steps_to_predict > self.args.max_dream_steps_per_job: + return make_response(f"num_steps_to_predict too large. Max {self.args.max_dream_steps_per_job}", 400) + + num_parallel_predictions = int(request_json['num_parallel_predictions']) if 'num_parallel_predictions' in request_json else 1 + + if (self.jobs.qsize() + num_parallel_predictions) >= self.args.max_concurrent_jobs: + return make_response(f"Too many jobs already running. Max {self.args.max_concurrent_jobs}", 400) + + for key in sampling_settings: + sampling_settings[key] = float_or_none(request_json[key]) if key in request_json else sampling_settings[key] + + context_images = [] + context_actions = [] + context_tokens = [] + future_actions = [] + + for step in request_json["steps"]: + image_path = step["image_name"] + image = np.array(Image.open(request.files[image_path].stream)) + image = be_image_preprocess(image, target_width=self.args.image_width, target_height=self.args.image_height) + context_images.append(th.from_numpy(image)) + + action = step["action"] + action = action_vector_to_be_action_vector(action) + context_actions.append(th.tensor(action)) + + tokens = step["tokens"] + context_tokens.append(tokens) + + future_actions = None + if "future_actions" in request_json: + future_actions = [] + for step in request_json["future_actions"]: + # The rest is the action vector + action = step["action"] + action = action_vector_to_be_action_vector(action) + # Add sequence and batch dimensions + future_actions.append(th.tensor(action)) + + # Add batch dimensions + context_images = th.stack(context_images).unsqueeze(0) + context_actions = th.stack(context_actions).unsqueeze(0) + future_actions = th.stack(future_actions).unsqueeze(0) if future_actions is not None else None + + list_of_job_ids = [] + for _ in range(num_parallel_predictions): + job_id = uuid.uuid4().hex + self.jobs.put(DreamJob( + job_id=job_id, + sampling_settings=sampling_settings, + num_predictions_remaining=num_steps_to_predict, + num_predictions_done=0, + context_images=context_images, + context_actions=context_actions, + context_tokens=context_tokens, + actions_to_take=future_actions + )) + list_of_job_ids.append(job_id) + + job_queue_size = self.jobs.qsize() + return json.dumps({"job_ids": list_of_job_ids, "current_jobs_in_queue": job_queue_size}) + + def _gather_new_results(self): + if not self.results_queue.empty(): + for _ in range(self.results_queue.qsize()): + result = self.results_queue.get() + if result.job_id in self.cancelled_jobs: + # Discard result if job was cancelled + continue + self.local_results[result.job_id].append(result) + + def get_new_results(self, request, request_json): + if "job_ids" not in request_json: + return make_response("job_ids not in request", 400) + self._gather_new_results() + job_ids = request_json["job_ids"] + if not isinstance(job_ids, list): + job_ids = [job_ids] + return_results = [] + for job_id in job_ids: + if job_id in self.local_results: + return_results.append(self.local_results[job_id]) + del self.local_results[job_id] + + if len(return_results) == 0: + return make_response("No new responses", 204) + + output_json = [] + output_image_bytes = {} + for job_results in return_results: + for result in job_results: + action = result.dreamt_action.numpy() + # Remember to remove batch and sequence dimensions + action = be_action_vector_to_action_vector(action[0, 0].tolist()) + dreamt_tokens = result.dreamt_tokens[0, 0].tolist() + image_filename = f"{result.job_id}_{result.dream_step_index}.png" + output_json.append({ + "job_id": result.job_id, + "dream_step_index": result.dream_step_index, + "action": action, + "tokens": dreamt_tokens, + "image_filename": image_filename + }) + + image_bytes = io.BytesIO() + # this probably is not as smooth as it could be + T.ToPILImage()(result.dreamt_image[0, 0]).save(image_bytes, format="PNG") + output_image_bytes[image_filename] = image_bytes.getvalue() + + # Write a zip file with all the images + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] + zip_bytes = io.BytesIO() + with zipfile.ZipFile(zip_bytes, "w") as z: + for filename, bytes in output_image_bytes.items(): + z.writestr(filename, bytes) + # Write the json + z.writestr(PREDICTION_JSON_FILENAME, json.dumps(output_json)) + + zip_bytes.seek(0) + + return send_file( + zip_bytes, + mimetype="zip", + as_attachment=True, + download_name=f"dreaming_results_{timestamp}.zip" + ) + + def cancel_job(self, request, request_json): + if "job_id" not in request_json: + return make_response("job_id not in request", 400) + job_id = request_json["job_id"] + self.cancelled_jobs.add(job_id) + # Cancel all jobs in the queue with this id + for job_queue in self.cancelled_jobs_queues: + job_queue.put(job_id) + return make_response("OK", 200) + + +def main_run(args): + app = Flask(__name__) + + num_workers = th.cuda.device_count() + if num_workers == 0: + raise RuntimeError("No CUDA devices found. Cannot run Dreamer.") + + server = DreamerServer(num_workers, args) + quit_flag = mp.Event() + + # Start the dreamer worker(s) + dreamer_worker_processes = [] + for device_i in range(num_workers): + device = f"cuda:{device_i}" + dreamer_worker_process = mp.Process( + target=dreamer_worker, + args=(server.jobs, server.results_queue, server.cancelled_jobs_queues[device_i], quit_flag, device, args) + ) + dreamer_worker_process.daemon = True + dreamer_worker_process.start() + dreamer_worker_processes.append(dreamer_worker_process) + + # Add the API endpoints + @app.route('/') + def details(): + return server.get_details() + + @app.route('/new_job', methods=['POST']) + def new_job(): + request_json = json.loads(request.form["json"]) + return server.add_new_job(request, request_json) + + @app.route('/get_job_results', methods=['GET']) + def get_results(): + # the "Json" is now in regular GET payload/parameters + request_json = {"job_ids": request.args.getlist("job_ids")} + return server.get_new_results(request, request_json) + + @app.route('/cancel_job', methods=['GET']) + def cancel_job(): + request_json = request.args.to_dict() + return server.cancel_job(request, request_json) + + app.run(host="0.0.0.0", port=args.port, debug=args.debug) + + # Cleanup + quit_flag.set() + for dreamer_worker_process in dreamer_worker_processes: + dreamer_worker_process.join() + + +if __name__ == '__main__': + args = parser.parse_args() + main_run(args) diff --git a/setup_local.sh b/setup_local.sh new file mode 100755 index 0000000000000000000000000000000000000000..69c996ab704b648ffc29f43413b6c09f6064262d --- /dev/null +++ b/setup_local.sh @@ -0,0 +1,21 @@ +# Tested using Python 3.9 + +echo "Making and activating a new virtual environment..." +python3.9 -m venv venv + +echo "Activating the virtual environment..." +source venv/bin/activate + +echo "Upgrading pip..." +pip install --upgrade pip + +echo "Instaling the required packages..." +pip install -r requirements.txt + +echo "Instaling the exiftool package for adding file metadata on Linux..." +sudo apt install -y exiftool + +echo "Installing ffmpeg..." +sudo apt install ffmpeg + +echo "All packages installed successfully!" diff --git a/wham/models/nn/model_blocks.py b/wham/models/nn/model_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..b3b35e76b7c94abd1214252377b4768805be0853 --- /dev/null +++ b/wham/models/nn/model_blocks.py @@ -0,0 +1,49 @@ +import torch.nn as nn + +""" +Some Utility blocks for ViT-VQGAN. + +ConvNeXt blocks are based on: +Liu, Zhuang, et al. "A convnet for the 2020s." +Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022. +""" + + +class ConvNextDownsampleBig(nn.Module): + def __init__(self, c_in, c_out): + super().__init__() + self.group_norm = nn.GroupNorm(c_in, c_in) + self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=8, stride=4, padding=0) + + def forward(self, x): + return self.conv1(self.group_norm(x)) + + +class ConvNextBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.conv1 = nn.Conv2d(channels, channels, kernel_size=7, stride=1, padding=7 // 2, groups=channels) # 'Depthwise' conv + self.group_norm = nn.GroupNorm(channels, channels) # Should be equivalent to layernorm + + # Transformer-style non-linearity + self.conv2 = nn.Conv2d(channels, channels * 4, kernel_size=1, stride=1, padding=0) + self.activation = nn.GELU() + self.conv3 = nn.Conv2d(channels * 4, channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + y = self.conv1(x) + y = self.group_norm(y) + y = self.conv2(y) + y = self.activation(y) + y = self.conv3(y) + return x + y + + +class ConvNextDownsample(nn.Module): + def __init__(self, c_in, c_out): + super().__init__() + self.group_norm = nn.GroupNorm(c_in, c_in) + self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1) + + def forward(self, x): + return self.conv1(self.group_norm(x)) diff --git a/wham/models/nn/nanoGPT.py b/wham/models/nn/nanoGPT.py new file mode 100644 index 0000000000000000000000000000000000000000..91b7ed5caf7f896a48ed6c999f069d4d3c9c0ab1 --- /dev/null +++ b/wham/models/nn/nanoGPT.py @@ -0,0 +1,665 @@ +# From https://github.com/karpathy/nanoGPT/blob/master/model.py - Thanks Andrej Karpathy + +# MIT License +# Copyright (c) 2022 Andrej Karpathy +# 2023 Microsoft Research + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE +# OR OTHER DEALINGS IN THE SOFTWARE. + + +""" +Full definition of a GPT Language Model, all of it in this single file. +References: +1) the official GPT-2 TensorFlow implementation released by OpenAI: +https://github.com/openai/gpt-2/blob/master/src/model.py +2) huggingface/transformers PyTorch implementation: +https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py +""" + +from dataclasses import dataclass +import inspect +import math + +import torch +import torch.nn as nn +from torch.nn import functional as F + +NEGATIVE_INFINITE_FLOAT = -float("inf") +CROSS_ENTROPY_INVALID_CLASS_TARGET = -1 + +# @torch.jit.script # good to enable when not using torch.compile, disable when using (our default) +def new_gelu(x): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). + Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 + """ + return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +def limit_logits_to_valid_range(logits, valid_token_range): + """ + MODIFIES logits INPLACE. + Mask out invalid positions in the logits tensor with -inf so they are not considered by the softmax. + + Args: + logits: logits tensor of shape (batch_size, vocab_size) + valid_token_range: tuple of (start, end) indices of valid positions in the logits tensor (inclusive). + Everything outside is masked out with -inf. + """ + logits[:, : valid_token_range[0]] = NEGATIVE_INFINITE_FLOAT + logits[:, valid_token_range[1] + 1 :] = NEGATIVE_INFINITE_FLOAT + + +def default_sample_token(logits, valid_token_range=None, temperature=1.0, deterministic=False, top_k=None, top_p=None, min_tokens_to_keep=1): + """ + Given a vector of logits, sample and return an index according to settings. + + logits: tensor of shape (batch_size, vocab_size) + + valid_token_range should be a tuple, specifying start and end indices we'd like to sample from (inclusive). + If None, we'll sample from the full vocab. + + If deterministic is True, we'll take the argmax of the logits which implies top-k sampling with top_k = 1, therefore user inputted values of top_p and top_k will be ignored. + + Otherwise, either top-p (float) value can be specified or top-k (int) value can be specified. + Top-p (float top_p) : only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. + Top-k (int top_k) : selects top_k tokens for generation. + min_tokens_to_keep: Used with both top_p and top_k sampling. + """ + assert top_k is None or top_p is None, "Can only specify one of top-k or top-p sampling." + if temperature < 0.1: + # Avoid too low a temp, especially 0 + temperature = 0.1 + logits = logits / temperature + if valid_token_range is not None: + limit_logits_to_valid_range(logits, valid_token_range) + if deterministic: + selected_logits = select_logits(logits, top_k=1) + else: + selected_logits = select_logits(logits, top_p=top_p, top_k=top_k, min_tokens_to_keep=min_tokens_to_keep) + probs = F.softmax(selected_logits, dim=-1) + # More robustly handle errors in the sampling here + sampled_idx = torch.multinomial(probs, num_samples=1).squeeze(-1) + return sampled_idx + + +def select_logits(logits, top_k=None, top_p=None, min_tokens_to_keep=1): + """ + Select from original logits using top-k or top-p sampling. + + Args: + logits (torch.Tensor): Logits to sample from. + k (int, optional): Number of top elements to consider in top-k sampling. + p (float, optional): Threshold probability for top-p sampling. + min_tokens_to_keep (int, optional): Minimum number of tokens to keep in the output. + + Returns: + logits: Selected logits after top-k or top-p sampling. Sets all logits outside the selected ones to NEGATIVE_INFINITE_FLOAT. + """ + assert top_k is None or top_p is None, "Can only specify one of top-k or top-p sampling." + min_tokens_to_keep = min(min_tokens_to_keep, logits.size(-1)) + if top_k is not None: + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}") + + # Top-k sampling + top_k = max(top_k, min_tokens_to_keep) + top_k = min(top_k, logits.size(-1)) + top_k_logits, _ = torch.topk(logits, top_k) + indices_to_remove = logits < top_k_logits[..., -1:] + logits = torch.where(indices_to_remove, NEGATIVE_INFINITE_FLOAT, logits) + + elif top_p is not None: + top_p = float(top_p) + if top_p < 0 or top_p > 1.0: + raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}") + + # Top-p sampling + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + sorted_probs = torch.softmax(sorted_logits, dim=-1) + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove[..., :min_tokens_to_keep] = False + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove) + logits = torch.where(indices_to_remove, NEGATIVE_INFINITE_FLOAT, logits) + + else: + # Return logits as is + pass + + return logits + + +class LayerNorm(nn.Module): + """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False""" + + def __init__(self, ndim, bias): + super().__init__() + self.weight = nn.Parameter(torch.ones(ndim)) + self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None + + def forward(self, input): + return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) + +class LayerNormMinimal(nn.Module): + """LayerNorm like above, but without learnable parameters""" + + def __init__(self, ndim, bias): + super().__init__() + self.ndim = (ndim,) + + def forward(self, input): + return F.layer_norm(input, self.ndim, eps=1e-5) + + +class CausalSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + self.n_head = config.n_head + self.n_embd = config.n_embd + self.dropout = config.dropout + # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary + self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0 + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size), persistent=False) + + self.cached_k = None + self.cached_v = None + self.current_cache_size = 0 + + def _manual_causal_attention(self, q, k, v, mask): + # q, k and v should be of shape (B, nh, T, hs) + token_len = q.size(-2) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(mask[:, :, :token_len, :token_len] == 0, float("-inf")) + att = F.softmax(att, dim=-1) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + return y + + def forward(self, x, cache=False): + batch_size, token_len, n_embd = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + if self.flash and not cache: + # efficient attention using Flash Attention CUDA kernels + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True) + elif cache: + # manual implemention of attention (as below), but cache arrays we can reuse + assert token_len == 1, "Cache only works for single step" + assert self.cached_k is not None, "Must call reset_cache() before using cache" + assert self.current_cache_size < self.cached_k.size(2), "Trying to generate more steps than provided in reset_cache() `num_steps_to_come`" + assert self.dropout == 0.0, "Dropout not supported with caching" + this_step_q = q + self.cached_k[:, :, self.current_cache_size, :] = k[:, :, 0, :] + self.cached_v[:, :, self.current_cache_size, :] = v[:, :, 0, :] + # Remove the zero parts + k = self.cached_k[:, :, : self.current_cache_size + 1, :] + # compute last row of the attention mask + this_step_att_row = (this_step_q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + this_step_att_row = F.softmax(this_step_att_row, dim=-1) + # We only need output for the current step + y = this_step_att_row @ self.cached_v[:, :, : self.current_cache_size + 1, :] + # Update cache + self.current_cache_size += 1 + else: + y = self._manual_causal_attention(q, k, v, self.bias) + y = y.transpose(1, 2).contiguous().view(batch_size, token_len, n_embd) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + def reset_cache(self, x, num_steps_to_come): + """ + Reset caches by doing initial pass with x data (returning same output as forward). + Also set the number of steps to come, which is used to initialize the buffers + """ + batch_size, token_len, n_embd = x.size() + + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # Use SDPA instead of a manual implementation + # y = self._manual_causal_attention(q, k, v, self.bias) + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True) + + y = y.transpose(1, 2).contiguous().view(batch_size, token_len, n_embd) + # output projection + y = self.resid_dropout(self.c_proj(y)) + + # Create full k,q,v for predicting all future steps. + # Just null-out the last num_steps_to_come-1 steps + pad_size = num_steps_to_come + self.current_cache_size = token_len + self.cached_k = torch.cat([k, torch.zeros(batch_size, self.n_head, pad_size, n_embd // self.n_head, device=k.device)], dim=2) + self.cached_v = torch.cat([v, torch.zeros(batch_size, self.n_head, pad_size, n_embd // self.n_head, device=v.device)], dim=2) + + return y + +class SelfAttention(nn.Module): + """ + Non-causal self-attention layer, the same as CausalSelfAttention but without the causal mask. + Duplicating the code to keep this separate for clarity. + """ + + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + self.n_head = config.n_head + self.n_embd = config.n_embd + self.dropout = config.dropout + # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary + self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0 + assert self.flash, "SelfAttention only supports flash attention for now." + + self.register_buffer("attn_mask", torch.ones((config.block_size, config.block_size)).bool().unsqueeze(0).unsqueeze(0)) + + def forward(self, x): + batch_size, token_len, n_embd = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=self.attn_mask, dropout_p=self.dropout, is_causal=False) + y = y.transpose(1, 2).contiguous().view(batch_size, token_len, n_embd) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + +class MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.dropout = nn.Dropout(config.dropout) + + def forward(self, x): + x = self.c_fc(x) + x = new_gelu(x) + x = self.c_proj(x) + x = self.dropout(x) + return x + +class GELU_MLP(nn.Module): + """MLP Block using PyTorch's native GELU activation function""" + def __init__(self, config): + super().__init__() + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.dropout = nn.Dropout(config.dropout) + + def forward(self, x): + x = self.c_fc(x) + x = F.gelu(x, approximate="tanh") + x = self.c_proj(x) + x = self.dropout(x) + return x + + +class Block(nn.Module): + def __init__(self, config): + super().__init__() + self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) + self.attn = CausalSelfAttention(config) + self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) + self.mlp = MLP(config) + + def forward(self, x, cache=False, reset_cache_with_num_steps_to_come=None): + """ + Args: + cache: If True, use the cache to predict the next token (assumes model was initialized with `reset_cache`). + reset_cache_with_num_steps_to_come: + If not None, reset and prepare the cache for cached prediction of the next `reset_cache_with_num_steps_to_come` tokens. + This is same as calling `reset_cache` with the same argument, but we include option here in `forward` to support torch hook functions (used to get embeddings from this module output). + + Caching example: + ``` + # Initialize model with reset_cache_with_num_steps_to_come=10 + outputs[0] = model(inputs, reset_cache_with_num_steps_to_come=10) + # Predict next 10 tokens using cache + for i in range(10): + outputs[i+1] = model(inputs, cache=True) + ``` + """ + if reset_cache_with_num_steps_to_come: + return self.reset_cache(x, num_steps_to_come=reset_cache_with_num_steps_to_come) + x = x + self.attn(self.ln_1(x), cache=cache) + x = x + self.mlp(self.ln_2(x)) + return x + + def reset_cache(self, x, num_steps_to_come): + x = x + self.attn.reset_cache(self.ln_1(x), num_steps_to_come=num_steps_to_come) + x = x + self.mlp(self.ln_2(x)) + return x + +class BlockV2(nn.Module): + """ + Compared to the Block in the original implementation, this one uses non-parametric LayerNorm and Pytorch's GELU. + These two changes save significant vram but are incompatible with previously trained models. + Hence the separate class. + """ + + def __init__(self, config): + super().__init__() + self.ln_1 = LayerNormMinimal(config.n_embd, bias=config.bias) + self.attn = CausalSelfAttention(config) + self.ln_2 = LayerNormMinimal(config.n_embd, bias=config.bias) + self.mlp = GELU_MLP(config) + + def forward(self, x, cache=False, reset_cache_with_num_steps_to_come=None): + if reset_cache_with_num_steps_to_come: + return self.reset_cache(x, num_steps_to_come=reset_cache_with_num_steps_to_come) + x = x + self.attn(self.ln_1(x), cache=cache) + x = x + self.mlp(self.ln_2(x)) + return x + + def reset_cache(self, x, num_steps_to_come): + x = x + self.attn.reset_cache(self.ln_1(x), num_steps_to_come=num_steps_to_come) + x = x + self.mlp(self.ln_2(x)) + return x + +class SelfAttentionBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) + self.attn = SelfAttention(config) + self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) + self.mlp = MLP(config) + + def forward(self, x): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + +@dataclass +class GPTConfig: + block_size: int = 1024 + vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + dropout: float = 0.0 + bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + version: int = 1 # Version 1 is the original GPT, Version 2 is the one with non-parametric LayerNorm and Pytorch's GELU + + +class GPT(nn.Module): + def __init__(self, config): + super().__init__() + assert config.vocab_size is not None + assert config.block_size is not None + self.config = config + + self.version = config.version + + print(f"[nanoGPT] creating model with version {self.version}") + + if self.version == 1: + transformer_dict = dict( + wpe=nn.Embedding(config.block_size, config.n_embd), + drop=nn.Dropout(config.dropout), + h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), + ln_f=LayerNorm(config.n_embd, bias=config.bias), + ) + elif self.version == 2: + transformer_dict = dict( + wpe=nn.Embedding(config.block_size, config.n_embd), + drop=nn.Dropout(config.dropout), + h=nn.ModuleList([BlockV2(config) for _ in range(config.n_layer)]), + ln_f=LayerNorm(config.n_embd, bias=config.bias), # This one is still parametric due to user error + ) + + transformer_dict["wte"] = nn.Embedding(config.vocab_size, config.n_embd) + self.transformer = nn.ModuleDict(transformer_dict) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + # with weight tying when using torch.compile() some warnings get generated: + # "UserWarning: functional_call was passed multiple values for tied weights. + # This behavior is deprecated and will be an error in future versions" + # not 100% sure what this is, so far seems to be harmless. + self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying + + # init all weights + self.apply(self._init_weights) + # apply special scaled init to the residual projections, per GPT-2 paper + for pn, p in self.named_parameters(): + if pn.endswith("c_proj.weight"): + torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) + + def get_num_params(self, non_embedding=True): + """ + Return the number of parameters in the model. + For non-embedding count (default), the position embeddings get subtracted. + The token embeddings would too, except due to the parameter sharing these + params are actually used as weights in the final layer, so we include them. + """ + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.transformer.wpe.weight.numel() + return n_params + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def _apply_pos_encoding(self, x): + device = x.device + token_len = x.size(1) + pos = torch.arange(0, token_len, dtype=torch.long, device=device).unsqueeze(0) + pos_emb = self.transformer.wpe(pos) + x = x + pos_emb + return x + + def original_forward(self, idx, targets=None, loss_mask=None, loss_reduction="mean"): + batch_size, seq_len = idx.shape[:2] + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + x = self.transformer.drop(self._apply_pos_encoding(tok_emb)) + for block in self.transformer.h: + x = block(x) + x = self.transformer.ln_f(x) + + if targets is not None: + # if we are given some desired targets also calculate the loss + logits = self.lm_head(x) + if loss_mask is not None: + # Feeding target = CROSS_ENTROPY_INVALID_CLASS_TARGET to cross_entropy will ignore the loss + # for that position. This is useful for padding tokens. + targets[loss_mask == 0] = CROSS_ENTROPY_INVALID_CLASS_TARGET + loss = F.cross_entropy( + logits.view(batch_size * seq_len, self.config.vocab_size), targets.view(-1), ignore_index=CROSS_ENTROPY_INVALID_CLASS_TARGET, reduction=loss_reduction + ) + if loss_reduction == "none": + # Reshape back into batch_size and seq_len + loss = loss.view(batch_size, seq_len) + else: + # inference-time mini-optimization: only forward the lm_head on the very last position + logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim + loss = None + + return logits, loss + + def forward(self, x, targets=None, loss_mask=None, loss_reduction="mean"): + token_len = x.size(1) + assert token_len <= self.config.block_size, f"Cannot forward sequence of length {token_len}, block size is only {self.config.block_size}" + return self.original_forward(x, targets, loss_mask, loss_reduction) + + @torch.no_grad() + def generate(self, idx, max_new_tokens, valid_token_range=None, temperature=1.0, top_k=None, raise_cropping=False, deterministic=False): + """ + valid_token_range should be a tuple, specifying start and end indices we'd like to sample from (inclusive). + if None, we'll sample from the full vocab. + + If raise_cropping is True, we'll raise an error if we need to crop the sequence context. + """ + if valid_token_range is None: + valid_token_range = (0, self.config.vocab_size - 1) + assert len(valid_token_range) == 2 + assert valid_token_range[0] < valid_token_range[1] + for _ in range(max_new_tokens): + # if the sequence context is growing too long we must crop it at block_size + idx_cond = idx + if idx.size(1) > self.config.block_size: + if raise_cropping: + raise ValueError("Tried to crop idxs but flag told to raise this") + else: + idx_cond = idx[:, -self.config.block_size :] + # forward the model to get the logits for the index in the sequence + logits, _ = self(idx_cond) + # pluck the logits at the final step and scale by desired temperature + logits = logits[:, -1, :] / temperature # logits is B T Vocabsize -> B Vocabsize + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = NEGATIVE_INFINITE_FLOAT + + # Crop out the logits we don't want to sample from + if valid_token_range is not None: + limit_logits_to_valid_range(logits, valid_token_range) + + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(logits, dim=-1) + + if deterministic: + # Take max of the results + idx_next = torch.argmax(probs, dim=-1, keepdim=True) + else: + # sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) + # append sampled index to the running sequence and continue + idx = torch.cat((idx, idx_next), dim=1) + + return idx + + @torch.no_grad() + def optimized_generate( + self, + idx, + num_new_tokens, + valid_token_ranges=None, + temperature=1.0, + deterministic=False, + raise_cropping=False, + top_k=None, + top_p=None, + min_tokens_to_keep=1, + ): + """ + Generate function but optimized by caching the results in transformer blocks (think this is referred to as "attention caching"). + The higher the num_new_tokens, the more the speedup compared to original generate. + + Caveat: the context length + num_new_tokens must be less than the block size. This means that the first + generated tokens do not have full context length. + + valid_token_ranges should be None or list of length num_new_tokens, specifying valid range for tokens for every step + """ + # Properly compile the modules used and/or quantize for improved speed. + logit_layer = self.lm_head + embedder_fn = self.transformer.wte + + if valid_token_ranges is None: + valid_token_ranges = [[0, self.config.vocab_size] for _ in range(num_new_tokens)] + assert len(valid_token_ranges) == num_new_tokens, "valid_token_ranges should be list of length num_new_tokens or None" + + _, token_len = idx.size() + if token_len + num_new_tokens > self.config.block_size: + raise ValueError("Can't use optimized generation with num_new_tokens + context_length > block_size") + new_idxs = torch.zeros(idx.size(0), num_new_tokens, dtype=torch.long, device=idx.device) + # First, we need to cull the sequence to the block size + # and remove first max_new_tokens so we can reuse same position embeddings + # and not have to recompute them + num_original_tokens = idx.size(1) + original_idx = idx + if (num_original_tokens + num_new_tokens) > self.config.block_size: + if raise_cropping: + raise ValueError("Tried to crop idxs but flag told to raise this") + original_idx = idx[:, -self.config.block_size + num_new_tokens :] + original_pos = torch.arange(0, original_idx.size(1), dtype=torch.long, device=idx.device).unsqueeze(0) + # Now cache results with the original context + original_tok_emb = embedder_fn(original_idx) + original_pos_emb = self.transformer.wpe(original_pos) + original_x = original_tok_emb + original_pos_emb + for block in self.transformer.h: + # Reset the cache for each block, and cache new result + original_x = block(original_x, reset_cache_with_num_steps_to_come=num_new_tokens) + + # Sample the first token + original_x = self.transformer.ln_f(original_x) + last_logit = logit_layer(original_x[:, [-1], :]) + new_idxs[:, 0] = default_sample_token( + last_logit[:, -1, :], valid_token_ranges[0], temperature, deterministic, top_k=top_k, top_p=top_p, min_tokens_to_keep=min_tokens_to_keep + ) + + # Generate rest of the steps + for generation_idx in range(1, num_new_tokens): + # forward the model to get the logits for the index in the sequence + # This is the position of the latest generated token, not the currently going-to-be-generated token + latest_token_pos = num_original_tokens + generation_idx - 1 + # We only need to pass in the latest token + newest_idx = new_idxs[:, generation_idx - 1].unsqueeze(-1) + newest_tok_emb = embedder_fn(newest_idx) + newest_pos_emb = self.transformer.wpe(torch.tensor(latest_token_pos, dtype=torch.long, device=idx.device).unsqueeze(0)) + newest_x = newest_tok_emb + newest_pos_emb + for block in self.transformer.h: + newest_x = block(newest_x, cache=True) + + newest_x = self.transformer.ln_f(newest_x) + newest_logit = logit_layer(newest_x) + # Check this function isn't slowing things down noticeably + new_idxs[:, generation_idx] = default_sample_token( + newest_logit[:, -1, :], valid_token_ranges[generation_idx], temperature, deterministic, top_k=top_k, top_p=top_p, min_tokens_to_keep=min_tokens_to_keep + ) + + # Combine indices + new_idxs = torch.cat((idx, new_idxs), dim=1) + return new_idxs diff --git a/wham/models/pl/__init__.py b/wham/models/pl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/wham/models/pl/pl_base_model.py b/wham/models/pl/pl_base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b4e1116e6e0e6b99de542fc03a2becc03834241f --- /dev/null +++ b/wham/models/pl/pl_base_model.py @@ -0,0 +1,5 @@ +import pytorch_lightning as pl + +class BaseTrainingModel(pl.LightningModule): + def __init__(self, **kwargs): + super().__init__(**kwargs) diff --git a/wham/models/vqgan/taming/LICENSE b/wham/models/vqgan/taming/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..695b205f7bec56833b25dc4560ad2868a41b96fe --- /dev/null +++ b/wham/models/vqgan/taming/LICENSE @@ -0,0 +1,24 @@ +All files under this directory are originally from the taming-transformers repository: +https://github.com/CompVis/taming-transformers + +Below is a copy of the original license +------------------------------------------------------------------------------ +Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE +OR OTHER DEALINGS IN THE SOFTWARE./ \ No newline at end of file diff --git a/wham/models/vqgan/taming/model.py b/wham/models/vqgan/taming/model.py new file mode 100644 index 0000000000000000000000000000000000000000..fc0b9491ebe522d733964ca3e04da7ba3e5a8a06 --- /dev/null +++ b/wham/models/vqgan/taming/model.py @@ -0,0 +1,696 @@ +# All files under this directory are originally from the taming-transformers repository: +# https://github.com/CompVis/taming-transformers + +# MIT License +# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer +# 2023 Microsoft Research + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE +# OR OTHER DEALINGS IN THE SOFTWARE. + +import math +import torch +import torch.nn as nn +import numpy as np + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class Model(nn.Module): + def __init__( + self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, use_timestep=True + ): + super().__init__() + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in + skip_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, x, t=None): + # assert x.shape[2] == x.shape[3] == self.resolution + + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + **ignore_kwargs + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) + + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + **ignorekwargs + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class VUNet(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + c_channels, + resolution, + z_channels, + use_timestep=False, + **ignore_kwargs + ): + super().__init__() + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d(c_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + self.z_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=1, stride=1, padding=0) + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=2 * block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in + skip_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, x, z): + # assert x.shape[2] == x.shape[3] == self.resolution + + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + z = self.z_in(z) + h = torch.cat((h, z), dim=1) + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h diff --git a/wham/models/vqgan/taming/quantize.py b/wham/models/vqgan/taming/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..8399f63a8f8959ba255fad7ceba0711da4b87069 --- /dev/null +++ b/wham/models/vqgan/taming/quantize.py @@ -0,0 +1,146 @@ +# All files under this directory are originally from the taming-transformers repository: +# https://github.com/CompVis/taming-transformers + +# MIT License +# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer +# 2023 Microsoft Research + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE +# OR OTHER DEALINGS IN THE SOFTWARE. + +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + + +class VectorQuantizer2(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" + assert rescale_logits == False, "Only for interface compatible with Gumbel" + assert return_logits == False, "Only for interface compatible with Gumbel" + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, "b c h w -> b h w c").contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")) + ) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q diff --git a/wham/models/vqgan/taming_vq_model.py b/wham/models/vqgan/taming_vq_model.py new file mode 100644 index 0000000000000000000000000000000000000000..01c0743d4fc50a375af696a20e5b3edf2e620037 --- /dev/null +++ b/wham/models/vqgan/taming_vq_model.py @@ -0,0 +1,264 @@ +# Wrapper for the VQ models from the taming-transformers repo +# https://github.com/CompVis/taming-transformers + +from typing import Any, Mapping +import pytorch_lightning as pl +import torch +import torch.nn.functional as F + +from wham.models.vqgan.taming.model import Encoder, Decoder +from wham.models.vqgan.taming.quantize import VectorQuantizer2 as VectorQuantizer + +from wham.models.wham_base.tensor_spaces import TensorSpace +from wham.models.wham_base.encoder_decoder import EncoderDecoderBase + + +HARDCODED_IMAGE_SIZE = 128 + + +def taming_vq_preprocess_images(imgs): + """Normalize images (as pytorch tensor uint8s) as in taming-transformers""" + return imgs.float() / 127.5 - 1.0 + + +def taming_vq_revert_preprocess_images(imgs): + """Revert preprocessing of images from taming to uint8 as in taming-transformers""" + # Clamp first + imgs = torch.clamp(imgs, -1.0, 1.0) + return ((imgs + 1) * 127.5).byte() + + +class _VQModelFromTamingRepository(pl.LightningModule): + """ + This aims to be the original VQ model from the taming-transformers repo with as little modifications as possible. This should not be used directly. + Source: https://github.com/CompVis/taming-transformers/blob/master/taming/models/vqgan.py + + MIT License + Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer + 2023 Microsoft Research + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, + DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE + OR OTHER DEALINGS IN THE SOFTWARE. + """ + + def __init__( + self, + ddconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + # NOTE: Loss is disabled for this repo (we only want inference) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + # Note: the '!= "None"' check is for checkpoints that mistakenly stored the None as a string + if ckpt_path is not None and ckpt_path != "None": + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.image_key = image_key + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def forward(self, input): + quant, diff, _ = self.encode(input) + dec = self.decode(quant) + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) + return x.float() + + def training_step(self, batch, batch_idx, optimizer_idx): + raise NotImplementedError("This copy of the model code does not support training") + + def validation_step(self, batch, batch_idx): + raise NotImplementedError("This copy of the model code does not support training") + + def configure_optimizers(self): + raise NotImplementedError("This copy of the model code does not support training") + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 + return x + + +class TamingVQModel(EncoderDecoderBase): + + __DEBUG_CREATION_KWARGS__ = { + "ckpt_path": None, + "model_spec": { + "taming_n_embed": 16, + "taming_embed_dim": 8, + "taming_num_indices_per_axis": 8, + "taming_ddconfig": { + "double_z": False, + "z_channels": 16, + "resolution": 128, + "in_channels": 3, + "out_ch": 3, + "ch": 128, + "ch_mult": [1, 1, 1, 1, 1], + "num_res_blocks": 1, + "attn_resolutions": [16], + "dropout": 0.0, + }, + }, + } + + def __init__(self, model_spec, ckpt_path, **kwargs): + super().__init__() + self._vocab_size = model_spec["taming_n_embed"] + self.num_indices_per_axis = model_spec["taming_num_indices_per_axis"] + self.num_indices_total = self.num_indices_per_axis**2 + self.taming_embed_dim = model_spec["taming_embed_dim"] + taming_ddconfig = model_spec.get("taming_ddconfig", None) + if taming_ddconfig is None: + raise ValueError("To run TamingVQModel, specify model_spec.taming_ddconfig, which should match the ddconfig used when training the model") + + self.vq_model = _VQModelFromTamingRepository(taming_ddconfig, self._vocab_size, self.taming_embed_dim, ckpt_path=ckpt_path) + + resolution = taming_ddconfig["resolution"] + in_channels = taming_ddconfig["in_channels"] + self.world_space = TensorSpace((in_channels, resolution, resolution), dtype=torch.uint8, low=0, high=255) + self.encoder_space = TensorSpace((self.num_indices_total,), dtype=torch.long, low=0, high=self.vocab_size - 1) + + @property + def vocab_size(self): + """Return the number of entries in the codebook.""" + return self._vocab_size + + @property + def encoded_bottleneck_dim(self): + """Return the dimensionality of the latent vector encoded into codebook indices.""" + return self.num_indices_total + + def _preprocess_images(self, images): + """Preprocess images (B, C, H, W)""" + return taming_vq_preprocess_images(images) + + def _revert_image_preprocess(self, x_batch): + """Revert the preprocessing done in _preprocess_images""" + return taming_vq_revert_preprocess_images(x_batch) + + def decode_from_encoding_indices(self, encoding_indices, return_vq_embeddings=False): + """Return decoded images (B, C, H, W) for a batch of encoding indices (B, self.encoded_bottleneck_dim)""" + batch_size = encoding_indices.shape[0] + z = self.vq_model.quantize.get_codebook_entry(encoding_indices, shape=(batch_size, self.num_indices_per_axis, self.num_indices_per_axis, self.taming_embed_dim)) + data_recon = self.vq_model.decode(z) + # Denormalize and cast to uint8 + data_recon = self._revert_image_preprocess(data_recon) + if return_vq_embeddings: + return data_recon, z + return data_recon + + def get_encoding_indices_for_images(self, images): + """ + Return encoding indices (B, self.encoded_bottleneck_dim) for a batch of images (B, C, H, W). + Useful auxiliary method for testing. + """ + x_batch = self._preprocess_images(images) + _, _, (_, _, encoding_indices) = self.vq_model.encode(x_batch) + # Split back into (B, self.encoded_bottleneck_dim) + encoding_indices = encoding_indices.view(images.shape[0], -1) + return encoding_indices + + def forward_returning_action_and_embedding(self, states, actions_input, timesteps, attention_mask, images): + seq_len_dim = 1 + assert images.shape[seq_len_dim] == 1, f"We require seq_len==1, but provided {images.shape[seq_len_dim]}." + images = images.squeeze(dim=seq_len_dim) # get rid of timestep dimension + x_batch = self._preprocess_images(images) + quant, _, (_, _, encoding_indices) = self.vq_model.encode(x_batch) + # Split back into (B, self.encoded_bottleneck_dim) + encoding_indices = encoding_indices.reshape(quant.shape[0], 1, quant.shape[2], quant.shape[3]) + quant = quant.unsqueeze(seq_len_dim) + return None, {"quantized": quant, "encoding_indices": encoding_indices} + + def _encode(self, world_space_tensor: torch.tensor) -> torch.tensor: + batch, time = world_space_tensor.shape[:2] + world_space_tensor = world_space_tensor.view(batch * time, *world_space_tensor.shape[2:]) + encodings = self.get_encoding_indices_for_images(world_space_tensor) + # Reshape back to (batch, time, ...) + encodings = encodings.view(batch, time, -1) + return encodings + + def _decode(self, encoder_space_tensor: torch.tensor) -> torch.tensor: + batch, time = encoder_space_tensor.shape[:2] + encoder_space_tensor = encoder_space_tensor.view(batch * time, *encoder_space_tensor.shape[2:]) + decoded = self.decode_from_encoding_indices(encoder_space_tensor) + # Reshape back to (batch, time, ...) + decoded = decoded.view(batch, time, *decoded.shape[1:]) + return decoded diff --git a/wham/models/vqgan/vqgan.py b/wham/models/vqgan/vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..4cc63347ad0f15b9ea36943e90039d8d00478f38 --- /dev/null +++ b/wham/models/vqgan/vqgan.py @@ -0,0 +1,236 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from wham.models.wham_base.tensor_spaces import TensorSpace +from wham.models.wham_base.encoder_decoder import EncoderDecoderBase + +from wham.models.vqgan import vqgan_models as vqgan +from wham.models.vqvae.vqvae_utils import make_grid, normalise_rgb, rev_normalise_rgb + +from pytorch_lightning.loggers.tensorboard import TensorBoardLogger +from pytorch_lightning.loggers.wandb import WandbLogger + +TARGET_GAN_UPDATE = 5 +GAN_DWEIGHT_MAX = 250 +GAN_LOGIT_CAP = 5.0 +MAX_PIXEL_WEIGHTING = 0.1 + +# The GAN parts are from Taming Transformers (https://github.com/CompVis/taming-transformers) +""" +ViT-VQGAN is based on: +Yu, Jiahui, et al. "Vector-quantized image modeling with improved vqgan." +ICLR 2022 +""" + + +def create_vqgan_model_for_training(variant): + return VQGANModel(variant=variant) + + +class VQGANModel(EncoderDecoderBase): + @classmethod + def create_from_variant(cls, variant): + return VQGANModel(variant=variant) + + def __init__(self, variant=None, ckpt_path=None, model_spec=None): + super().__init__() + self.save_hyperparameters() + self.variant = variant + if model_spec is not None: + self.model_spec = model_spec + else: + self.model_spec = variant["model_spec"] + + # Batches of images we will use for logging + self.reference_x_batch = None # Same images used throughout training to see progress of the model + self.random_batch = None # Different images every iteration + + if variant is None and "image_size_per_y_axis" in self.model_spec: + self.image_size_x = self.model_spec["image_size_per_x_axis"] + self.image_size_y = self.model_spec["image_size_per_y_axis"] + else: + assert "image_size_per_x_axis" in variant and "image_size_per_y_axis" in variant, "Please provide the image size as separate x and y for the VQGAN model" + self.image_size_x = variant["image_size_per_x_axis"] + self.image_size_y = variant["image_size_per_y_axis"] + + self._embedding_dim = self.model_spec["embedding_dim"] + self.encoder = vqgan.ViTEncoder( + patch_size=self.model_spec["patch_size"], + transf_dim=self.model_spec["transf_dim"], + embedding_dim=self.model_spec["embedding_dim"], + image_size_x=self.image_size_x, + image_size_y=self.image_size_y, + num_layers=self.model_spec["num_layers"], + head_size=self.model_spec["head_size"], + ) + self._bottleneck_size = self.encoder.bottleneck + + self.vq_vae = vqgan.ViTVectorQuantizer( + self.model_spec["vocab_size"], + self.model_spec["embedding_dim"], + self.model_spec["commitment_cost"], + ) + + self.decoder = vqgan.ViTDecoder( + patch_size=self.model_spec["patch_size"], + transf_dim=self.model_spec["transf_dim"], + embedding_dim=self.model_spec["embedding_dim"], + image_size_x=self.image_size_x, + image_size_y=self.image_size_y, + num_layers=self.model_spec["num_layers"], + head_size=self.model_spec["head_size"], + expected_bottleneck=self._bottleneck_size, + ) + + self.is_perceptual = self.model_spec["is_perceptual"] + assert self.is_perceptual # This should be on + + # Keep track of the usage of the codebook indices + self.codebook_index_usage = np.zeros(self.model_spec["vocab_size"], dtype=np.int64) + + self.gan = self.model_spec.get("use_gan", False) + if self.gan: + # Only make the patchgan if we are using it. This makes it easier to experiment with GAN settings after pretraining the VQ-VAE for instance + self.patch_gan = vqgan.PatchGan(channel_start=self.model_spec["gan_channel_start"]) + # Make a copy of the patchgan since we are only using a single optimizer + self.target_patchgan = vqgan.PatchGan(channel_start=self.model_spec["gan_channel_start"]) + self.target_patchgan.requires_grad_(False) + self.target_patchgan.load_state_dict(self.patch_gan.state_dict()) + self.target_update = TARGET_GAN_UPDATE + + # At which iteration to start using the GAN loss + self.gan_start = self.model_spec["gan_start"] + # How much weight to give to the GAN loss gradients compared to the vq autoencoder loss + self.gan_weight = self.model_spec["gan_weight"] + # How many steps to train the discriminator before applying the gan loss. + self.gan_discrim_pretrain = self.model_spec["gan_discrim_pretrain"] + # How many steps to warmup the gan loss + self.gan_discrim_warmup = self.model_spec["gan_discrim_warmup"] + # Keeping track of the number of updates + self.updates = 0 + print(f"Using GAN with weight {self.gan_weight} and target update {self.target_update} and gan start {self.gan_start} over {self.gan_discrim_warmup} steps") + + self.lpips_model = None + # We don't need this for using the encoder/decoder + # self.lpips_model = lpips.LPIPS(net=self.model_spec["lpips_model"]).eval() + # for param in self.lpips_model.parameters(): + # param.requires_grad = False + + if ckpt_path is not None and ckpt_path != "None": + print(f"Initing VQGAN model from {ckpt_path}") + loaded_ckpt = torch.load(ckpt_path, map_location="cpu") + # Can ignore stuff here + self.load_state_dict(loaded_ckpt["state_dict"], strict=False) + + self.world_space = TensorSpace((3, self.image_size_y, self.image_size_x), dtype=torch.uint8, low=0, high=255) + self.encoder_space = TensorSpace((self._bottleneck_size,), dtype=torch.long, low=0, high=self.vocab_size - 1) + + @property + def vocab_size(self): + """Return the number of entries in the codebook.""" + return self.vq_vae._vocab_size + + @property + def encoded_bottleneck_dim(self): + """Return the dimensionality of the latent vector encoded into codebook indices.""" + return self._bottleneck_size + + @property + def embedding_dim(self): + """The dimensionality of quantized vectors (the dimension of codebook vectors).""" + return self.vq_vae._embedding_dim + + def _get_last_layer(self): + """ + The last layer used for generating the image. + Used for balancing the gradients of the reconstruction and the GAN loss. + """ + return self.decoder.get_last_layer() + + def _preprocess_images(self, images): + """Preprocess images (B, C, H, W)""" + x_batch = images.float() / 255 + x_batch = normalise_rgb(x_batch) + return x_batch + + def _revert_image_preprocess(self, x_batch): + """Revert the preprocessing done in _preprocess_images""" + normalized_imgs = rev_normalise_rgb(x_batch.clone()) + x_batch = torch.clip(normalized_imgs, 0, 1) + images = (x_batch * 255).byte() + return images + + def _get_latent_continuous(self, batch): + z = self.encoder(batch) + return z + + def _get_latent_discretized(self, z): + z_quantized, vq_loss, perplexity, indices = self.vq_vae(z) + return z_quantized, vq_loss, perplexity, indices + + def _encode_decode(self, x_batch): + z = self._get_latent_continuous(x_batch) + z_quantized, vq_loss, perplexity, indices = self._get_latent_discretized(z) + data_recon = self.decoder(z_quantized) + return vq_loss, perplexity, data_recon, indices + + def _log_vars(self, log_vars): + prefix = "train" if self.training else "val" + for key, val in log_vars.items(): + self.log(f"{prefix}/{key}", val, on_step=True, on_epoch=True, prog_bar=False, logger=True, sync_dist=True) + + def decode_from_encoding_indices(self, encoding_indices): + """Return decoded images (B, C, H, W) for a batch of encoding indices (B, self.encoded_bottleneck_dim)""" + z = self.vq_vae.convert_encoding_indices_to_quantized_embeddings(encoding_indices) + data_recon = self.decoder(z) + # Denormalize and cast to uint8 + data_recon = self._revert_image_preprocess(data_recon) + return data_recon + + def get_encoding_indices_for_images(self, images): + """ + Return encoding indices (B, self.encoded_bottleneck_dim) for a batch of images (B, C, H, W). + Useful auxiliary method for testing. + """ + x_batch = self._preprocess_images(images) + z = self._get_latent_continuous(x_batch) + encoding_indices = self.vq_vae(z, only_return_encoding_indices=True) + return encoding_indices + + def forward_returning_action_and_embedding(self, states, actions_input, timesteps, attention_mask, images): + raise NotImplementedError + + def get_encoding_output(self, images): + """ + Return outputs from the encoder for a batch of images (B, C, H, W). + Returns: + quantized_z: (B, self.encoded_bottleneck_dim, self.embedding_dim), quantized latent vectors with straight-through gradient estimator + vq_loss: (B, ), VQ loss for each image + perplexity: (B, ), perplexity for each image + encoding_indices: (B, self.encoded_bottleneck_dim), encoding indices for each image + """ + x_batch = self._preprocess_images(images) + z = self._get_latent_continuous(x_batch) + quantized_z, vq_loss, perplexity, encoding_indices = self.vq_vae(z) + quantized_z = quantized_z.view(quantized_z.shape[0], self.encoded_bottleneck_dim, self.embedding_dim) + return quantized_z, vq_loss, perplexity, encoding_indices + + def _encode(self, world_space_tensor: torch.tensor) -> torch.tensor: + # Flatten time and batch dim into one + batch, time = world_space_tensor.shape[:2] + world_space_tensor = world_space_tensor.view(batch * time, *world_space_tensor.shape[2:]) + encodings = self.get_encoding_indices_for_images(world_space_tensor) + # Reshape back to (batch, time, ...) + encodings = encodings.view(batch, time, -1) + return encodings + + def _decode(self, encoder_space_tensor: torch.tensor) -> torch.tensor: + # Flatten time and batch dim into one + batch, time = encoder_space_tensor.shape[:2] + encoder_space_tensor = encoder_space_tensor.view(batch * time, *encoder_space_tensor.shape[2:]) + decoded = self.decode_from_encoding_indices(encoder_space_tensor) + # Reshape back to (batch, time, ...) + decoded = decoded.view(batch, time, *decoded.shape[1:]) + return decoded diff --git a/wham/models/vqgan/vqgan_models.py b/wham/models/vqgan/vqgan_models.py new file mode 100644 index 0000000000000000000000000000000000000000..8fe2cb21eff53d558fbdd80dc0b440ac275aff64 --- /dev/null +++ b/wham/models/vqgan/vqgan_models.py @@ -0,0 +1,311 @@ +# MIT License +# Copyright (c) 2018 Zalando Research +# 2023 Microsoft Research + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE +# OR OTHER DEALINGS IN THE SOFTWARE. + +from math import sqrt + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from wham.models.nn.nanoGPT import GPTConfig, SelfAttentionBlock +from wham.models.nn.model_blocks import ConvNextBlock, ConvNextDownsample, ConvNextDownsampleBig + +# Mainly following https://github.com/zalandoresearch/pytorch-vq-vae/blob/master/vq-vae.ipynb +""" +ViT-VQGAN is based on: +Yu, Jiahui, et al. "Vector-quantized image modeling with improved vqgan." +ICLR 2022 +""" + + +def _convert_encoding_indices_to_quantized_embeddings(encoding_indices, embedding_layer, vocab_size, embedding_dim): + """ + Args: + encoding_indices: tensor of integers (batch_size, bottleneck_size) + Each batch item represents a single image as a sequence of integers (indeces of codebook vectors) + Output: + quantized: tensor of floats (batch_size, bottleneck_size, embedding_dim) + """ + batch_dim, bottleneck_size = encoding_indices.shape[:2] + + encoding_indices = encoding_indices.view(-1).unsqueeze(1) + one_hot_encoding_indices = torch.zeros(encoding_indices.shape[0], vocab_size, device=encoding_indices.device) + one_hot_encoding_indices.scatter_(1, encoding_indices, 1) + + quantized = torch.matmul(one_hot_encoding_indices, embedding_layer) + quantized = quantized.view(batch_dim, bottleneck_size, embedding_dim).contiguous() + return quantized + + +class ViTVectorQuantizer(nn.Module): + """ + Vector Quantizer for a Vision Transformer based VQ model using normalised codebook embeddings as in https://arxiv.org/abs/2110.04627. + """ + + def __init__(self, vocab_size, embedding_dim, commitment_cost, epsilon=1e-5): + super().__init__() + + self._embedding_dim = embedding_dim + self._vocab_size = vocab_size + self._epsilon = epsilon + + self._embedding = nn.Embedding(self._vocab_size, self._embedding_dim) + self._embedding.weight.data.uniform_(-1 / self._vocab_size, 1 / self._vocab_size) + self._commitment_cost = commitment_cost + + @property + def vocab_size(self): + """Return the number of entries in the codebook.""" + return self._vocab_size + + def convert_encoding_indices_to_quantized_embeddings(self, encoding_indices): + """ + Args: + encoding_indices: tensor of integers (batch_size, bottleneck_size) + Each batch item represents a single image as a sequence of integers (indeces of codebook vectors) + Output: + quantized: tensor of floats (batch_size, self._embedding_dim, bottleneck_size) + """ + return _convert_encoding_indices_to_quantized_embeddings(encoding_indices, F.normalize(self._embedding.weight), self._vocab_size, self._embedding_dim) + + def forward(self, inputs, only_return_encoding_indices=False): + """ + If only_return_encoding_indices is True, then only return the indices of codebook vectors + """ + input_shape = inputs.shape + + # Flatten input from Batch Tokens Embedding to B*T E + flat_input = inputs.view(-1, self._embedding_dim) + # Normalize inputs + flat_input = F.normalize(flat_input) + + # Embeddings are always normalized + embeddings_to_use = F.normalize(self._embedding.weight) + + # Calculate distances + distances = torch.sum(flat_input**2, dim=1, keepdim=True) + torch.sum(embeddings_to_use**2, dim=1) - 2 * torch.matmul(flat_input, embeddings_to_use.t()) + + # Encoding + encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) + if only_return_encoding_indices: + # Add back batch dimension + return encoding_indices.view(input_shape[0], -1) + one_hot_encoding_indices = torch.zeros(encoding_indices.shape[0], self._vocab_size, device=inputs.device) + one_hot_encoding_indices.scatter_(1, encoding_indices, 1) + + # Quantize and unflatten + quantized = torch.matmul(one_hot_encoding_indices, embeddings_to_use).view(input_shape) + + # Loss + e_latent_loss = F.mse_loss(quantized.detach(), inputs) + q_latent_loss = F.mse_loss(quantized, inputs.detach()) + loss = q_latent_loss + self._commitment_cost * e_latent_loss + + quantized = inputs + (quantized - inputs).detach() + avg_probs = torch.mean(one_hot_encoding_indices, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + self._epsilon))) + + return quantized, loss, perplexity, encoding_indices.view(input_shape[0], -1) + + +class ViTEncoder(nn.Module): + def __init__(self, patch_size, transf_dim, embedding_dim, image_size_x, image_size_y, num_layers, head_size): + super().__init__() + + self.image_size_x = image_size_x + self.image_size_y = image_size_y + # We will pad the image to make it divisible by patch_size + self.x_pad = (patch_size - (self.image_size_x % patch_size)) % patch_size + self.y_pad = (patch_size - (self.image_size_y % patch_size)) % patch_size + assert (self.image_size_x + self.x_pad) % patch_size == 0 and ( + self.image_size_y + self.y_pad + ) % patch_size == 0, "image_size_x and image_size_y must be divisible by patch_size" + + self.vit_tokens = ((image_size_x + self.x_pad) // patch_size) * ((image_size_y + self.y_pad) // patch_size) + self._bottleneck = self.vit_tokens + print(f"Bottleneck is {self.bottleneck} for image size {image_size_x}x{image_size_y} with ViT Encoder and patch size {patch_size}") + + self.patch_size = patch_size + self.transf_dim = transf_dim + self.embedding_dim = embedding_dim + + self.proj1 = nn.Linear(3 * patch_size * patch_size, transf_dim) + self.pos_embeds = nn.Embedding(self.vit_tokens, transf_dim) + + assert self.transf_dim % head_size == 0, "transf_dim must be divisible by head_size" + n_heads = self.transf_dim // head_size + transformer_config = GPTConfig(block_size=self.vit_tokens, n_layer=num_layers, n_head=n_heads, n_embd=transf_dim, bias=False, dropout=0) + self.vit = nn.Sequential(*[SelfAttentionBlock(transformer_config) for _ in range(num_layers)]) + + self.output_ln = nn.LayerNorm(transf_dim) + self.output_proj = nn.Linear(transf_dim, embedding_dim) + + # init all weights + self.apply(self._init_weights) + # apply special scaled init to the residual projections, per GPT-2 paper + for pn, p in self.named_parameters(): + if pn.endswith("c_proj.weight"): + torch.nn.init.normal_(p, mean=0.0, std=0.02 / sqrt(2 * transformer_config.n_layer)) + + @property + def bottleneck(self): + return self._bottleneck + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def forward(self, inputs): + # inputs: (batch_size, 3, image_size_x, image_size_y) + + # Patch input images + batch_size = inputs.shape[0] + padded_inputs = F.pad(inputs, (0, self.x_pad, 0, self.y_pad), mode="constant", value=0) + x = padded_inputs.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size) + num_x_patches = (self.image_size_x + self.x_pad) // self.patch_size + num_y_patches = (self.image_size_y + self.y_pad) // self.patch_size + + # inputs is of shape (batch_size, 3, num_x_patches, num_y_patches, patch_size, patch_size) + # Turn it into (batch_size, patches, input_dim) + patches = x.permute(0, 2, 3, 1, 4, 5).contiguous().view(batch_size, num_x_patches * num_y_patches, 3 * self.patch_size * self.patch_size) + + proj_patches = self.proj1(patches) + + pos_embeds = self.pos_embeds.weight.unsqueeze(0).repeat(batch_size, 1, 1) + vit_input = proj_patches + pos_embeds + vit_output = self.vit(vit_input) + + vit_output = self.output_ln(vit_output) + embeddings = self.output_proj(vit_output) + normalised_embeddings = F.normalize(embeddings, dim=-1) + + return normalised_embeddings + + +class ViTDecoder(nn.Module): + def __init__(self, patch_size, transf_dim, embedding_dim, image_size_x, image_size_y, num_layers, head_size, expected_bottleneck=None): + super().__init__() + + self.image_size_x = image_size_x + self.image_size_y = image_size_y + self.x_pad = (patch_size - (self.image_size_x % patch_size)) % patch_size + self.y_pad = (patch_size - (self.image_size_y % patch_size)) % patch_size + + assert (self.image_size_x + self.x_pad) % patch_size == 0 and ( + self.image_size_y + self.y_pad + ) % patch_size == 0, "image_size_x and image_size_y must be divisible by patch_size" + + self.vit_tokens = ((image_size_x + self.x_pad) // patch_size) * ((image_size_y + self.y_pad) // patch_size) + if expected_bottleneck is not None: + assert ( + self.vit_tokens == expected_bottleneck + ), f"Expected bottleneck of {expected_bottleneck} but got {self.vit_tokens} for image size {image_size_x}x{image_size_y} with ViT Decoder and patch size {patch_size}" + + self.patch_size = patch_size + self.transf_dim = transf_dim + self.embedding_dim = embedding_dim + + self.proj1 = nn.Linear(embedding_dim, transf_dim) + self.pos_embeds = nn.Embedding(self.vit_tokens, transf_dim) + + assert self.transf_dim % head_size == 0, "transf_dim must be divisible by head_size" + n_heads = self.transf_dim // head_size + transformer_config = GPTConfig(block_size=self.vit_tokens, n_layer=num_layers, n_head=n_heads, n_embd=transf_dim, bias=False, dropout=0) + self.vit = nn.Sequential(*[SelfAttentionBlock(transformer_config) for _ in range(num_layers)]) + + self.output_ln = nn.LayerNorm(transf_dim) + self.output_proj = nn.Linear(transf_dim, 3 * patch_size * patch_size) + + # Couldn't resist the name + self.folder = nn.Fold( + output_size=(self.image_size_y + self.y_pad, self.image_size_x + self.x_pad), + kernel_size=(self.patch_size, self.patch_size), + stride=(self.patch_size, self.patch_size), + ) + + # init all weights + self.apply(self._init_weights) + # apply special scaled init to the residual projections, per GPT-2 paper + for pn, p in self.named_parameters(): + if pn.endswith("c_proj.weight"): + torch.nn.init.normal_(p, mean=0.0, std=0.02 / sqrt(2 * transformer_config.n_layer)) + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def forward(self, inputs): + # Patch input images + batch_size = inputs.shape[0] + + # Unproject the embeddings from the VQ embedding space to the transformer space + proj_patches = self.proj1(inputs).reshape(batch_size, self.vit_tokens, self.transf_dim) + + pos_embeds = self.pos_embeds.weight.unsqueeze(0).repeat(batch_size, 1, 1) + vit_input = proj_patches + pos_embeds + vit_output = self.vit(vit_input) + + vit_output = self.output_ln(vit_output) + + predictions = self.output_proj(vit_output) # (batch, patches, 3 * patch_size * patch_size) + + # Reassemble the image into (batch, 3, image_size_x, image_size_y) + fold_inputs = predictions.permute(0, 2, 1).contiguous() + image_pred = self.folder(fold_inputs) + + unpadded_image_pred = image_pred[:, :, : self.image_size_y, : self.image_size_x] # Remove padding in the same way it was applied in the encoder + + # Anything on the output? + return unpadded_image_pred + + def get_last_layer(self): + """ + Return the last layer weights of the model, to use for loss balancing. + """ + return self.output_proj.weight + + +class PatchGan(nn.Module): + def __init__(self, channel_start): + super().__init__() + x = channel_start + self.downsample1 = ConvNextDownsampleBig(3, x) + self.block1 = ConvNextBlock(x) + self.downsample2 = ConvNextDownsampleBig(x, x) + self.block2 = ConvNextBlock(x) + self.last = nn.Conv2d(x, 1, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + batch_size = x.shape[0] + y = torch.nn.functional.gelu(self.downsample1(x)) + y = self.block1(y) + z = torch.nn.functional.gelu(self.downsample2(y)) + z = self.block2(z) + return self.last(z).reshape(batch_size, -1) \ No newline at end of file diff --git a/wham/models/vqvae/vqvae_utils.py b/wham/models/vqvae/vqvae_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..81e3718a5f81977a066a7985c9642b74f5c5a8b0 --- /dev/null +++ b/wham/models/vqvae/vqvae_utils.py @@ -0,0 +1,154 @@ +import math +from typing import List, Optional, Tuple, Union + +import torch + + +def normalise_rgb(X, channels_first=True): + """ + Take in an image tensor of shape [ ... , 3], which is assumed to have already been divided by + 255 so X \in [0,1]. These functions do additional normalisation, roughly ending up with mean + of zero and unit variance. The constants appeared in most vision repos, and are supposedly the + 'right' constants to use based on imagenet statistics. + assert X.shape[-1] == 3 + """ + channel_dim = 1 if channels_first else -1 + assert X.shape[channel_dim] == 3 + if channels_first: + X[:, 0, ...] -= 0.485 + X[:, 0, ...] /= 0.229 + X[:, 1, ...] -= 0.456 + X[:, 1, ...] /= 0.224 + X[:, 2, ...] -= 0.406 + X[:, 2, ...] /= 0.225 + else: + X[..., 0] -= 0.485 + X[..., 0] /= 0.229 + X[..., 1] -= 0.456 + X[..., 1] /= 0.224 + X[..., 2] -= 0.406 + X[..., 2] /= 0.225 + return X + + +def rev_normalise_rgb(X, channels_first=True): + """ + Reverse `normalise_rgb`, so the output lives in [0,1]. This function is needed for + reconstruction visualisation, etc. + """ + channel_dim = 1 if channels_first else -1 + assert X.shape[channel_dim] == 3 + if channels_first: + X[:, 0, ...] *= 0.229 + X[:, 0, ...] += 0.485 + X[:, 1, ...] *= 0.224 + X[:, 1, ...] += 0.456 + X[:, 2, ...] *= 0.225 + X[:, 2, ...] += 0.406 + else: + X[..., 0] *= 0.229 + X[..., 0] += 0.485 + X[..., 1] *= 0.224 + X[..., 1] += 0.456 + X[..., 2] *= 0.225 + X[..., 2] += 0.406 + return X + + +@torch.no_grad() +def make_grid( + tensor: Union[torch.Tensor, List[torch.Tensor]], + nrow: int = 8, + padding: int = 2, + normalize: bool = False, + value_range: Optional[Tuple[int, int]] = None, + scale_each: bool = False, + pad_value: float = 0.0, + **kwargs, +) -> torch.Tensor: + """ + Make a grid of images. + + Args: + tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W) + or a list of images all of the same size. + nrow (int, optional): Number of images displayed in each row of the grid. + The final grid size is ``(B / nrow, nrow)``. Default: ``8``. + padding (int, optional): amount of padding. Default: ``2``. + normalize (bool, optional): If True, shift the image to the range (0, 1), + by the min and max values specified by ``value_range``. Default: ``False``. + value_range (tuple, optional): tuple (min, max) where min and max are numbers, + then these numbers are used to normalize the image. By default, min and max + are computed from the tensor. + scale_each (bool, optional): If ``True``, scale each image in the batch of + images separately rather than the (min, max) over all images. Default: ``False``. + pad_value (float, optional): Value for the padded pixels. Default: ``0``. + + Returns: + grid (Tensor): the tensor containing grid of images. + """ + if not torch.is_tensor(tensor): + if isinstance(tensor, list): + for t in tensor: + if not torch.is_tensor(t): + raise TypeError(f"tensor or list of tensors expected, got a list containing {type(t)}") + else: + raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}") + + # if list of tensors, convert to a 4D mini-batch Tensor + if isinstance(tensor, list): + tensor = torch.stack(tensor, dim=0) + + if tensor.dim() == 2: # single image H x W + tensor = tensor.unsqueeze(0) + if tensor.dim() == 3: # single image + if tensor.size(0) == 1: # if single-channel, convert to 3-channel + tensor = torch.cat((tensor, tensor, tensor), 0) + tensor = tensor.unsqueeze(0) + + if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images + tensor = torch.cat((tensor, tensor, tensor), 1) + + if normalize is True: + tensor = tensor.clone() # avoid modifying tensor in-place + if value_range is not None and not isinstance(value_range, tuple): + raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers") + + def norm_ip(img, low, high): + img.clamp_(min=low, max=high) + img.sub_(low).div_(max(high - low, 1e-5)) + + def norm_range(t, value_range): + if value_range is not None: + norm_ip(t, value_range[0], value_range[1]) + else: + norm_ip(t, float(t.min()), float(t.max())) + + if scale_each is True: + for t in tensor: # loop over mini-batch dimension + norm_range(t, value_range) + else: + norm_range(tensor, value_range) + + if not isinstance(tensor, torch.Tensor): + raise TypeError("tensor should be of type torch.Tensor") + if tensor.size(0) == 1: + return tensor.squeeze(0) + + # make the mini-batch of images into a grid + nmaps = tensor.size(0) + xmaps = min(nrow, nmaps) + ymaps = int(math.ceil(float(nmaps) / xmaps)) + height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding) + num_channels = tensor.size(1) + grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value) + k = 0 + for y in range(ymaps): + for x in range(xmaps): + if k >= nmaps: + break + # Tensor.copy_() is a valid method but seems to be missing from the stubs + # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_ + grid.narrow(1, y * height + padding, height - padding).narrow(2, x * width + padding, width - padding).copy_(tensor[k]) # type: ignore[attr-defined] + k = k + 1 + return grid \ No newline at end of file diff --git a/wham/models/wham_base/__init__.py b/wham/models/wham_base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/wham/models/wham_base/encode_predict_decode_base.py b/wham/models/wham_base/encode_predict_decode_base.py new file mode 100644 index 0000000000000000000000000000000000000000..efe4df6426344490d8494b525187848a00f99b65 --- /dev/null +++ b/wham/models/wham_base/encode_predict_decode_base.py @@ -0,0 +1,256 @@ +from typing import Any, Union, Type, Callable, Tuple, Mapping, Optional + +import torch as th +import pytorch_lightning as pl +from tensordict import TensorDict # type: ignore # requires installing stubs for tensordict + +from .tensor_spaces import TensorDictSpace +from .encoder_decoder import EncoderDecoderBase +from .pl_creation_args import LightningModuleCreationArgs + + +def create_encoder_args_from_config_dict( + config_dict: dict[str, Union[dict[str, Any], tuple]], class_name_to_model: Callable[[str], Type[pl.LightningModule]] +) -> Mapping[str, Union[LightningModuleCreationArgs, Tuple[LightningModuleCreationArgs, LightningModuleCreationArgs]]]: + """ + Given a dictionary mapping modality names to their encoder-decoder arguments, create the corresponding + creation args (LightningModuleCreationArgs) for each modality. + + See LightningModuleCreationArgs.from_dict for more details. + + Args: + config_dict: A dictionary mapping modality names to their encoder-decoder arguments. + Root level of this dictionary should be modality names we expect. + class_name_to_model: A function mapping class names to their corresponding model classes. + + Returns: + A dictionary mapping modality names to their encoder-decoder creation args. + Each value may be a LightningModuleCreationArgs, or a tuple of two LightningModuleCreationArgs. + If value is a LightningModuleCreationArgs, then same model is used for encoding and decoding. + If value is a tuple of two LightningModuleCreationArgs, then first is used for encoding and second for decoding. + """ + # Giving explicit type hint here to make mypy happy + modalities: dict[str, Any] = {} + for modality_name, modality_config in config_dict.items(): + if isinstance(modality_config, (list, tuple)): + assert len(modality_config) == 2, f"Expected two entries for modality {modality_name}, got {len(modality_config)}" + modalities[modality_name] = ( + LightningModuleCreationArgs.from_dict(modality_config[0], class_name_to_model), + LightningModuleCreationArgs.from_dict(modality_config[1], class_name_to_model), + ) + else: + modalities[modality_name] = LightningModuleCreationArgs.from_dict(modality_config, class_name_to_model) + return modalities + + +def create_encoder_modules_from_args( + encoders: Mapping[str, Union[LightningModuleCreationArgs, Tuple[LightningModuleCreationArgs, LightningModuleCreationArgs]]], remove_checkpoint_path: bool = True +) -> th.nn.ModuleDict: + """ + Create the encoder modules from given creation args (LightningModuleCreationArgs). + + Args: + encoders: A dictionary mapping modality names to their encoder-decoder creation args. + If value is a LightningModuleCreationArgs, then same model is used for encoding and decoding. + If value is a tuple of two LightningModuleCreationArgs, then first is used for encoding and second for decoding. + remove_checkpoint_path: If True, then remove the checkpoint_path from the creation args. This prepares the + created moduled to be properly saved and loaded as part of the bigger model + + Returns: + A dictionary mapping modality names to their encoder-decoder modules. + """ + modalities = {} + for modality_name, modality_args in encoders.items(): + if isinstance(modality_args, (list, tuple)): + modalities[modality_name] = th.nn.ModuleList( + [ + modality_args[0].create_module(remove_checkpoint_path=remove_checkpoint_path), + modality_args[1].create_module(remove_checkpoint_path=remove_checkpoint_path), + ] + ) + else: + modalities[modality_name] = modality_args.create_module(remove_checkpoint_path=remove_checkpoint_path) + return th.nn.ModuleDict(modalities) + + +class EncodePredictDecodeModule(pl.LightningModule): + """ + Base-class for models that encode, predict and decode. + + Args: + context_encoders: A dictionary mapping modality names to their encoder-decoders. + If value is a pl.LightningModule, then same model is used for encoding and decoding. + If value is a tuple of two pl.LightningModule, then first is used for encoding and second for decoding. + condition_encoders: Same as `context_encoders`, but for conditions. + """ + + def __init__( + self, + predictor_args: LightningModuleCreationArgs, + context_encoders: th.nn.ModuleDict, + condition_encoders: Optional[th.nn.ModuleDict] = None, + ): + if condition_encoders is None: + condition_encoders = th.nn.ModuleDict(dict()) + self._assert_encoders(context_encoders) + self._assert_encoders(condition_encoders) + super().__init__() + + self.context_encoders = context_encoders + self.condition_encoders = condition_encoders + + self.context_world_space, self.context_encoder_space = self._get_spaces_from_encoders(context_encoders) + self.condition_world_space, self.condition_encoder_space = self._get_spaces_from_encoders(condition_encoders) + + self.predictor = predictor_args.create_module(context_space=self.context_encoder_space, condition_space=self.condition_encoder_space) + + def _assert_encoders(self, encoders: th.nn.ModuleDict) -> None: + """Check that encoder dictionary is valid""" + assert isinstance(encoders, th.nn.ModuleDict), f"Invalid type for encoders: {type(encoders)}. Expected th.nn.ModuleDict" + for modality_name, encoder in encoders.items(): + assert isinstance(encoder, EncoderDecoderBase) or isinstance( + encoder, th.nn.ModuleList + ), f"Invalid type for modality {modality_name}: {type(encoder)}. Expected EncoderDecoderBase or Tuple[EncoderDecoderBase]" + if isinstance(encoder, th.nn.ModuleList): + assert len(encoder) == 2, f"Invalid number of arguments for modality {modality_name}: {len(encoder)}. Expected two (encoder, decoder)" + assert isinstance( + encoder[0], EncoderDecoderBase + ), f"Invalid type for encoder of modality {modality_name}: {type(encoder[0])}. Expected EncoderDecoderBase" + assert isinstance( + encoder[1], EncoderDecoderBase + ), f"Invalid type for decoder of modality {modality_name}: {type(encoder[1])}. Expected EncoderDecoderBase" + + def _get_spaces_from_encoders(self, encoders: th.nn.ModuleDict) -> Tuple[TensorDictSpace, TensorDictSpace]: + """ + Given a modality dictionary mapping modality names to their encoders and decoders, + extract the world space and encoder space, + """ + world_spaces = {} + encoder_spaces = {} + for modality_name, modality in encoders.items(): + if isinstance(modality, EncoderDecoderBase): + encoder_spaces[modality_name] = modality.encoder_space + world_spaces[modality_name] = modality.world_space + elif isinstance(modality, th.nn.ModuleList): + assert len(modality) == 2, f"Invalid number of modules for modality {modality_name}: {len(modality)}. Expected 2." + # Make sure that both encoder and decoder spaces match the expected space + encoder_encoder_space = modality[0].encoder_space + decoder_encoder_space = modality[1].encoder_space + assert ( + encoder_encoder_space == decoder_encoder_space + ), f"Encoder and decoder spaces for modality {modality_name} do not match: {encoder_encoder_space} != {decoder_encoder_space}" + encoder_world_space = modality[0].world_space + decoder_world_space = modality[1].world_space + assert ( + encoder_world_space == decoder_world_space + ), f"Encoder and decoder world spaces for modality {modality_name} do not match: {encoder_world_space} != {decoder_world_space}" + encoder_spaces[modality_name] = encoder_encoder_space + world_spaces[modality_name] = encoder_world_space + else: + raise TypeError(f"Invalid type for modality {modality_name}: {type(modality)}. Expected EncoderDecoderBase or th.nn.ModuleList") + return TensorDictSpace(world_spaces), TensorDictSpace(encoder_spaces) + + def _encode(self, input_td: TensorDict, encoders: th.nn.ModuleDict, space: TensorDictSpace) -> TensorDict: + """ + Encode input_td into encoder space using the given encoders. + + Args: + input_td: A tensordict mapping modality names to their inputs. + encoders: A dictionary mapping modality names to their encoders. + + Returns: + An encoded tensordict. + """ + encoded_context = {} + preceding_dims = space.get_preceding_dimensions(input_td, allow_key_subset=True) + for modality_name in input_td.keys(): + encoder = encoders[modality_name] + if isinstance(encoder, EncoderDecoderBase): + encoded_context[modality_name] = encoder.encode(input_td[modality_name]) + elif isinstance(encoder, th.nn.ModuleList): + encoded_context[modality_name] = encoder[0].encode(input_td[modality_name]) + else: + raise TypeError(f"Invalid type for modality {modality_name}: {type(encoder)}. Expected EncoderDecoderBase or th.nn.ModuleList") + return TensorDict(encoded_context, batch_size=preceding_dims) + + def _decode(self, input_td: TensorDict, encoders: th.nn.ModuleDict, space: TensorDictSpace) -> TensorDict: + """ + Decode input_td into the original space using the given encoders. + + Args: + input_td: A tensordict mapping modality names to their encoded inputs. + encoders: A dictionary mapping modality names to their encoders. + + Returns: + A decoded tensordict. + """ + decoded_context = {} + preceding_dims = space.get_preceding_dimensions(input_td, allow_key_subset=True) + for modality_name in input_td.keys(): + encoder = encoders[modality_name] + if isinstance(encoder, EncoderDecoderBase): + decoded_context[modality_name] = encoder.decode(input_td[modality_name]) + elif isinstance(encoder, th.nn.ModuleList): + decoded_context[modality_name] = encoder[1].decode(input_td[modality_name]) + else: + raise TypeError(f"Invalid type for modality {modality_name}: {type(encoder)}. Expected EncoderDecoderBase or th.nn.ModuleList") + return TensorDict(decoded_context, batch_size=preceding_dims) + + def encode_context(self, context: TensorDict) -> TensorDict: + """ + Encode the given context into the encoder space. + + Args: + context: A tensordict mapping modality names to their inputs. + + Returns: + An encoded tensordict. + """ + assert self.context_world_space.contains(context, allow_key_subset=True), f"Context {context} is not contained in context world space {self.context_world_space}" + return self._encode(context, self.context_encoders, self.context_world_space) + + def decode_context(self, encoded_context: TensorDict) -> TensorDict: + """ + Decode the given encoded context into the original space. + + Args: + encoded_context: A tensordict mapping modality names to their encoded inputs. + + Returns: + A decoded tensordict. + """ + assert self.context_encoder_space.contains( + encoded_context, + allow_key_subset=True, + ), f"Encoded context {encoded_context} is not contained in context encoder space {self.context_encoder_space}" + return self._decode(encoded_context, self.context_encoders, self.context_encoder_space) + + def encode_condition(self, condition: TensorDict) -> TensorDict: + """ + Encode the given condition into the encoder space. + + Args: + condition: A tensordict mapping modality names to their inputs. + + Returns: + An encoded tensordict. + """ + assert self.condition_world_space.contains( + condition, allow_key_subset=True + ), f"Condition {condition} is not contained in condition world space {self.condition_world_space}" + return self._encode(condition, self.condition_encoders, self.condition_world_space) + + def decode_condition(self, encoded_condition: TensorDict) -> TensorDict: + """ + Decode the given encoded condition into the original space. + + Args: + encoded_condition: A tensordict mapping modality names to their encoded inputs. + + Returns: + A decoded tensordict. + """ + assert self.condition_encoder_space.contains( + encoded_condition, allow_key_subset=True + ), f"Encoded condition {encoded_condition} is not contained in condition encoder space {self.condition_encoder_space}" + return self._decode(encoded_condition, self.condition_encoders, self.condition_encoder_space) diff --git a/wham/models/wham_base/encoder_decoder.py b/wham/models/wham_base/encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ef62119304844b1d1300d86643347d7bceda7c90 --- /dev/null +++ b/wham/models/wham_base/encoder_decoder.py @@ -0,0 +1,113 @@ +from abc import ABC +from typing import Any + +import torch as th +import pytorch_lightning as pl + +# Using relative import so that this module is easier to move from this place to elsewhere +from .tensor_spaces import TensorSpace + + +class EncoderDecoderBase(pl.LightningModule): + """ + Base class for all encoders and decoders. + + Encoders turn datapoints from "world_space" to "encoder_space". + Decoders turn datapoints from "encoder_space" to "world_space". + + All tensors are in format (batch, time, ...), where 'batch' and 'time' dimensions + are always present (even if they are 1). Both world and encoder spaces can have + any number of dimensions in the '...' part. + """ + + # This is a dictionary of keyword arguments that can be used to create this class + # during testing/quick debugging (e.g., minimal model size) + __DEBUG_CREATION_KWARGS__: dict[str, Any] = dict() + + def __init__(self): + super().__init__() + self._world_space = None + self._encoder_space = None + + @property + def world_space(self) -> TensorSpace: + assert self._world_space is not None, "'world_space' is not defined. Set it with 'self.world_space = [TensorSpace]'." + return self._world_space + + @world_space.setter + def world_space(self, value: TensorSpace) -> None: + assert isinstance(value, TensorSpace), f"'world_space' must be of type TensorSpace, but is {type(value)}" + self._world_space = value + + @property + def encoder_space(self) -> TensorSpace: + assert self._encoder_space is not None, "'encoder_space' is not defined. Set it with 'self.encoder_space = [TensorSpace]'." + return self._encoder_space + + @encoder_space.setter + def encoder_space(self, value: TensorSpace) -> None: + assert isinstance(value, TensorSpace), f"'encoder_space' must be of type TensorSpace, but is {type(value)}" + self._encoder_space = value + + def encode(self, world_space_tensor: th.Tensor) -> th.Tensor: + """ + Encodes a tensor from world space to encoder space. + + The input tensor should match the world space of this encoder. + The input tensor may have any number of preceding dimensions (batch, time, ...), + and output result will be parallelly encoded for the preceding dimensions. + + Args: + world_space_tensor: Pytorch Tensor in world space (self.world_space.contains(world_space_tensor) == True)s + Returns: + Pytorch Tensor in encoder space (self.encoder_space.contains(return_value) == True) + """ + if not self.world_space.contains(world_space_tensor): + raise ValueError(f"Input tensor to `encode` {world_space_tensor} is not in world space {self.world_space}") + + preceding_dims = self.world_space.get_preceding_dimensions(world_space_tensor) + encoder_space_tensor = self._encode(world_space_tensor) + + if not self.encoder_space.contains(encoder_space_tensor): + raise ValueError(f"Output tensor from `_encode` {encoder_space_tensor} is not in encoder space {self.encoder_space}") + + new_preceding_dims = self.encoder_space.get_preceding_dimensions(encoder_space_tensor) + if new_preceding_dims != preceding_dims: + raise ValueError(f"Output tensor from `_encode` has preceding dimensions {new_preceding_dims}, but input tensor had preceding dimensions {preceding_dims}") + + return encoder_space_tensor + + def decode(self, encoder_space_tensor: th.Tensor) -> th.Tensor: + """ + Decodes a tensor from encoder space to world space. + + The input tensor should match the encoder space of this decoder. + The input tensor may have any number of preceding dimensions (batch, time, ...), + and output result will be parallelly decoded for the preceding dimensions. + + Args: + encoder_space_tensor: Pytorch Tensor in encoder space (self.encoder_space.contains(encoder_space_tensor) == True) + Returns: + Pytorch Tensor in world space (self.world_space.contains(return_value) == True) + """ + if not self.encoder_space.contains(encoder_space_tensor): + raise ValueError(f"Input tensor to `decode` {encoder_space_tensor} is not in encoder space {self.encoder_space}") + + preceding_dims = self.encoder_space.get_preceding_dimensions(encoder_space_tensor) + world_space_tensor = self._decode(encoder_space_tensor) + + if not self.world_space.contains(world_space_tensor): + raise ValueError(f"Output tensor from `_decode` {world_space_tensor} is not in world space {self.world_space}") + + # Make sure that the output tensor has the same preceding dimensions as the input tensor + new_preceding_dims = self.world_space.get_preceding_dimensions(world_space_tensor) + if new_preceding_dims != preceding_dims: + raise ValueError(f"Output tensor from `_decode` has preceding dimensions {new_preceding_dims}, but input tensor had preceding dimensions {preceding_dims}") + + return world_space_tensor + + def _encode(self, world_space_tensor: th.Tensor) -> th.Tensor: + raise NotImplementedError("Encoder function `_encode` not implemented") + + def _decode(self, encoder_space_tensor: th.Tensor) -> th.Tensor: + raise NotImplementedError("Decoder function `_decode` not implemented") diff --git a/wham/models/wham_base/pl_creation_args.py b/wham/models/wham_base/pl_creation_args.py new file mode 100644 index 0000000000000000000000000000000000000000..0df9180e969feaad1332240a25e3c4c3ad2eeb58 --- /dev/null +++ b/wham/models/wham_base/pl_creation_args.py @@ -0,0 +1,121 @@ +from typing import Any, Type, Callable, Optional + +import pytorch_lightning as pl + + +class LightningModuleCreationArgs: + """ + A creator class for holding arguments to define creating/loading a PL Module, either from scratch or from a checkpoint. + + Three combinations are possible: + - `pl_class` is provided, create a new module. Pass in `**pl_kwargs` (defaults to empty dict) as kwargs. + - `pl_class` and `pl_checkpoint_path` are provided, load a module from a checkpoint. + - `pl_checkpoint_path`, `pl_class` and `pl_kwargs` are provided, load a module from a checkpoint, and overwrite checkpoint arguments with kwargs. + + Additionally, there is `pl_stored_params_override`, which is a dictionary we use to update the module's hyperparameters when saving. + + Motivation: + Pytorch Lightning checkpoint saving stores the arguments passed to __init__ of the modules, and uses them to recreate the object. + If we are training a model that uses other PL modules, we usually want to pass in a path to a module checkpoint to restore module from. + (e.g., a bigger module uses pretrained encoder, and loads it from a checkpoint path). + However, when we save the module and try to restore it, it tries to use the same checkpoint path to load up the encoder, which probably doesn't exist. + (e.g., if we trained the bigger module on a different machine, and the current machine does not have the encoder checkpoint). + Ideally, we would want the bigger module contain everything to recreate the encoder, and not rely on a checkpoint path. + + This CreationArgs class aims to solve this by initially loading the encoder from a checkpoint path while training, but then + storing the arguments used to create the encoder in the bigger module, so that we can recreate the encoder from scratch when loading the bigger module. + + `pl_stored_params_override` can be used to replace custom hyperparameters of the module when saving, e.g. if we want to remove the checkpoint path. + """ + + def __init__( + self, pl_class: Type[pl.LightningModule], pl_checkpoint_path: Optional[str] = None, pl_stored_params_override: Optional[dict[str, Any]] = None, **pl_kwargs + ): + assert pl_class is not None, "Must provide a class for the PL module." + self.pl_class = pl_class + self.pl_checkpoint_path = pl_checkpoint_path + self.pl_kwargs = pl_kwargs + self.pl_stored_params_override = pl_stored_params_override + + @classmethod + def from_dict(self, config: dict[str, Any], class_name_to_class_fn: Callable[[str], Type[pl.LightningModule]]) -> "LightningModuleCreationArgs": + """ + Create a LightningModuleCreationArgs object from a config dict. + The config dictionary should have the following entries: + - `__class_name__`: str, name of the PL module class to create. This is passed to the `class_name_to_class_fn` function. + - `__checkpoint_path__`: The checkpoint path to load the PL module from (optional) + - Rest of the arguments are passed as **kwargs to the constructor + + Args: + config: The config dictionary. + class_name_to_class_fn: A function, mapping class names to classes. + Returns: + A LightningModuleCreationArgs object. + """ + assert "__class_name__" in config, "Must provide a class name for the PL module as `__class_name__`." + pl_class = class_name_to_class_fn(config["__class_name__"]) + checkpoint_path = config.get("__checkpoint_path__", None) + stored_params_override = config.get("__stored_hparams_override__", None) + + kwargs = config.copy() + del kwargs["__class_name__"] + if "__checkpoint_path__" in kwargs: + del kwargs["__checkpoint_path__"] + if "__stored_hparams_override__" in kwargs: + del kwargs["__stored_hparams_override__"] + + return LightningModuleCreationArgs(pl_class=pl_class, pl_checkpoint_path=checkpoint_path, pl_stored_params_override=stored_params_override, **kwargs) + + def create_module(self, remove_checkpoint_path: bool = True, **kwargs) -> pl.LightningModule: + """ + Create the PL module based on arguments: + - If `pl_checkpoint_path` is provided, load a module from a checkpoint. + - Otherwise, create a new module with `pl_class` and `pl_kwargs`. + + If `remove_checkpoint_path` is True, then the creation kwargs will be updated to match the checkpoint arguments, + and checkpoint path will be removed. This is to make loading nested models easier. + + **kwargs will be used to overwrite `pl_kwargs` if both are provided. + """ + pl_kwargs = self.pl_kwargs.copy() + pl_kwargs.update(kwargs) + if self.pl_checkpoint_path is not None: + pl_module = self.pl_class.load_from_checkpoint(self.pl_checkpoint_path, **pl_kwargs) + if remove_checkpoint_path: + self.update_creation_args_to_match_module(pl_module) + else: + # "remove_checkpoint_path" does not have an effect here, as all settings should be in the kwargs already. + pl_module = self.pl_class(**pl_kwargs) + # However, if we args we want to override the stored hparams, we can do that here. + if self.pl_stored_params_override is not None: + if pl_module.hparams is None: + # We have to make sure hyperparameters are stored in the module, otherwise they won't be saved. + pl_module.save_hyperparameters() + self.pl_kwargs.update(self.pl_stored_params_override) + pl_module.hparams.update(self.pl_kwargs) + self.pl_stored_params_override = None + + return pl_module + + def update_creation_args_to_match_module(self, module: pl.LightningModule) -> None: + """ + Update this object's creation arguments to match those of the saved model. + If this module already has kwargs, then these will be used to overwrite the checkpoint arguments. + Also removes the checkpoint path, if it exists. + + This is intended to be used to update the creation arguments of a model loaded from a checkpoint, + so that when creating nested models, we keep all hyperparameters in a single place + """ + assert isinstance(module, self.pl_class), f"Module is not of type {self.pl_class}" + module_hparams = module.hparams + if self.pl_kwargs is None: + self.pl_kwargs = dict() + module_hparams.update(self.pl_kwargs) + # More of a typing thing, but lets make sure all entries are valid arguments + # for model (i.e. keys are strings) + new_pl_kwargs = dict() + for key, value in module_hparams.items(): + assert isinstance(key, str), f"Key {key} in module hyperameters is not a string" + new_pl_kwargs[key] = value + self.pl_kwargs = new_pl_kwargs + self.pl_checkpoint_path = None diff --git a/wham/models/wham_base/predictor.py b/wham/models/wham_base/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..677575be1ff020023508830dbcaa6041ceaf2281 --- /dev/null +++ b/wham/models/wham_base/predictor.py @@ -0,0 +1,180 @@ +from typing import Any, Optional, Union + +import torch as th +import pytorch_lightning as pl +from tensordict import TensorDict # type: ignore # requires installing stubs for tensordict + +# Using relative import so that this module is easier to move from this place to elsewhere +from .tensor_spaces import TensorSpace, TensorDictSpace + + +def assert_space_tuple_is_valid(space_tuple: tuple[TensorSpace]) -> None: + """ + Checks that `space_tuple` is a valid tuple of TensorSpace objects. + An empty tuple is valid. + """ + if not isinstance(space_tuple, tuple): + raise ValueError(f"Space tuple {space_tuple} is not a tuple") + + for space in space_tuple: + assert isinstance(space, TensorSpace), f"Space {space} is not of type TensorSpace" + + +def assert_check_that_space_is_in_valid_spaces(space: TensorSpace, valid_spaces: tuple[TensorSpace]) -> None: + """ + Checks that `space` is in `valid_spaces`, and raises an error if not. + """ + if len(valid_spaces) == 0: + raise ValueError(f"Valid spaces is empty, so {space} cannot be in it") + + for valid_space in valid_spaces: + assert isinstance(valid_space, TensorSpace), f"Valid space {valid_space} is not of type TensorSpace" + if space.is_subset_of(valid_space): + return + raise ValueError(f"Space {space} is not a subset of any of the valid spaces {valid_spaces}") + + +def assert_check_that_space_dict_is_in_valid_spaces(space_dict: TensorDictSpace, valid_spaces: tuple[TensorSpace]) -> None: + """ + Checks that `space_dict` is a valid dictionary defining space name to TensorSpace, and that all spaces in the dict + can part of `valid_spaces`. + """ + if not isinstance(space_dict, TensorDictSpace): + raise ValueError(f"Space dict {space_dict} is not a TensorDictSpace") + + for space_name, space in space_dict.items(): + space = space_dict[space_name] + assert isinstance(space_name, str), f"Space name {space_name} is not a string" + assert isinstance(space, TensorSpace), f"Space {space} is not of type TensorSpace" + assert_check_that_space_is_in_valid_spaces(space, valid_spaces) + + +def assert_check_that_tensor_dict_is_valid_for_spaces(tensor_dict: TensorDict, spaces: TensorDictSpace, n_preceding_dims: Optional[int]) -> None: + """ + Checks that input dictionary is a valid instantiation of the given spaces. + - tensor_dict keys should match or be subset of spaces keys. + - All values in tensor_dict should be in the corresponding space + - All tensors have the same preceding dimensions (e.g., batch or batch and time) + - Number of preceding dimensions is equal to `n_preceding_dims`, if provided + + Raises an error if any of these checks fail. + + Args: + tensor_dict: TensorDict mapping space names to tensors. + spaces: Dictionary mapping space names to TensorSpace + n_preceding_dims: Number of preceding dimensions that all tensors should have (if None, this is not checked) + """ + if not isinstance(tensor_dict, TensorDict): + raise ValueError("Input dictionaries should be instances of tensordict.TensorDict") + + if not set(tensor_dict.keys()).issubset(set(spaces.keys())): + raise ValueError(f"Input dict keys {tensor_dict.keys()} does not match or is not a subset of space keys {spaces.keys()}") + + expected_preceding_dims = None + for tensor_name, tensor_tensor in tensor_dict.items(): + assert isinstance(tensor_name, str), f"Input name {tensor_name} is not a string" + assert isinstance(tensor_tensor, th.Tensor), f"Tensor {tensor_tensor} is not a tensor" + + space = spaces[tensor_name] + if not space.contains(tensor_tensor): + raise ValueError(f"Tensor {tensor_name} is not in space {space}") + + preceding_dims = space.get_preceding_dimensions(tensor_tensor) + + if n_preceding_dims is not None and len(preceding_dims) != n_preceding_dims: + raise ValueError(f"Tensor {tensor_name} has {len(preceding_dims)} preceding dimensions, but expected {n_preceding_dims}") + + if expected_preceding_dims is None: + expected_preceding_dims = preceding_dims + else: + assert preceding_dims == expected_preceding_dims, f"Tensor {tensor_name} has preceding dims {preceding_dims}, but expected {expected_preceding_dims}" + + +class PredictorBase(pl.LightningModule): + """ + Base class for "predictor" torch modules, which are used to predict future states from a context. + + Args: + context_space: A TensorDictSpace defining the space for each context modality. + Must have at least one context modality. + condition_space: A dictionary mapping condition names to their spaces. + This may be empty or None if there are no conditions. + """ + + # This is a dictionary of keyword arguments that can be used to create this class + # during testing/quick debugging (e.g., minimal model size) + __DEBUG_CREATION_KWARGS__: dict[str, Any] = dict() + + # These class attributes are used to define the acceptable spaces for the predictor. + # They are used in the `__init__` method to check that the spaces user is trying to use are valid for this predictor + _acceptable_context_spaces: Optional[Union[TensorSpace, tuple[TensorSpace]]] = None + _acceptable_condition_spaces: Optional[Union[tuple[TensorSpace], tuple[TensorSpace]]] = None + + def __init__(self, context_space: TensorDictSpace, condition_space: Optional[TensorDictSpace] = None): + if condition_space is None: + condition_space = TensorDictSpace(dict()) + assert_check_that_space_dict_is_in_valid_spaces(context_space, self.acceptable_context_spaces) + assert_check_that_space_dict_is_in_valid_spaces(condition_space, self.acceptable_condition_spaces) + assert len(context_space) > 0, "There must be at least one context encoder space" + super().__init__() + + self.context_space = context_space + self.condition_space = condition_space + + @property + def acceptable_context_spaces(self) -> tuple[TensorSpace]: + self_class = self.__class__ + _acceptable_context_spaces = self_class._acceptable_context_spaces + assert ( + _acceptable_context_spaces is not None + ), f"Class {self_class} has no _acceptable_context_spaces class property defined. This should be a tuple of TensorSpace" + if not isinstance(_acceptable_context_spaces, tuple): + _acceptable_context_spaces = (_acceptable_context_spaces,) + try: + assert_space_tuple_is_valid(_acceptable_context_spaces) + except Exception as e: + raise AssertionError(f"Class {self_class} has an invalid _acceptable_context_spaces class property. See above for details") from e + return _acceptable_context_spaces + + @property + def acceptable_condition_spaces(self) -> tuple[TensorSpace]: + self_class = self.__class__ + _acceptable_condition_spaces = self_class._acceptable_condition_spaces + assert ( + _acceptable_condition_spaces is not None + ), f"Class {self_class} has no _acceptable_condition_spaces class property defined. This should be a tuple of TensorSpace" + if not isinstance(_acceptable_condition_spaces, tuple): + _acceptable_condition_spaces = (_acceptable_condition_spaces,) + try: + assert_space_tuple_is_valid(_acceptable_condition_spaces) + except Exception as e: + raise AssertionError(f"Class {self_class} has an invalid _acceptable_condition_spaces class property. See above for details") from e + return _acceptable_condition_spaces + + def assert_check_context_tensordict_is_valid(self, context_dict: TensorDict) -> None: + """ + Checks that context tensordict is a valid set of modalities for this predictor: + - context_dict should have all the modalities that this predictor expects + - All the tensors should be contained in the spaces that this predictor expects + - Preceding dimensions of all tensors should be the same, and have two die dimensions (batch and time) + + Raises an error if the context dictionary is invalid. + """ + try: + assert_check_that_tensor_dict_is_valid_for_spaces(context_dict, self.context_space, n_preceding_dims=2) + except Exception as e: + raise ValueError(f"Context TensorDict {context_dict} is not valid for this predictor. See above exception for more info.") from e + + def assert_check_condition_dict_is_valid(self, condition_dict: TensorDict) -> None: + """ + Checks that input dictionary is a valid set of modalities for this predictor: + - context_dict should have all the modalities that this predictor expects + - All the tensors should be contained in the spaces that this predictor expects + - Conditions should _only_ have batch dimension + + Raises an error if the input is invalid. + """ + try: + assert_check_that_tensor_dict_is_valid_for_spaces(condition_dict, self.condition_space, n_preceding_dims=1) + except Exception as e: + raise ValueError(f"Condition TensorDict {condition_dict} is not valid for this predictor. See above exception for more info.") from e diff --git a/wham/models/wham_base/tensor_spaces.py b/wham/models/wham_base/tensor_spaces.py new file mode 100644 index 0000000000000000000000000000000000000000..15f5b40966bbefb4c91995bb76ce6f8f3d1d6346 --- /dev/null +++ b/wham/models/wham_base/tensor_spaces.py @@ -0,0 +1,483 @@ +# Akin to gym.spaces but for pytorch + +from typing import Tuple, Union, Iterator, Optional + +import torch as th +from tensordict import TensorDict # type: ignore # requires installing stubs for tensordict + + +def assert_is_valid_space_shape_tuple(shape: Tuple[Union[int, None]]) -> None: + """ + Check if a tuple if valid for defining a space shape. + A space shape is a tuple of integers or "None". A "None" indicates a wildcard dimension, + + Args: + shape: The tuple to check. + """ + assert isinstance(shape, tuple), f"Input {shape} is not a tuple." + assert len(shape) > 0, f"Input tuple {shape} is not a valid shape. It must have at least one dimension." + for dim in shape: + if dim is not None and not isinstance(dim, int): + raise ValueError(f"Input tuple {shape} is not a valid shape. Entry '{dim}' is not an integer or None.") + + +def check_if_shape_is_in_space_shape(shape: Union[Tuple[Optional[int]], th.Size], space_shape: Tuple[Union[int, None]]) -> bool: + """ + Check if a shape is inside space shape. + + A shape is inside a space shape if the number of dimensions match and all dimensions are equal. + If space shape has a wildcard "None", then the corresponding dimension in the shape can be int or None. + If shape has a wildcard "None", then the corresponding dimension in the space shape must be None. + + Args: + shape: The tensor shape to check. + space_shape: The space shape to check. + + Returns: + True if the tensor shape is in the space shape, False otherwise. + """ + assert len(shape) == len(space_shape), "Input shapes must have the same number of dimensions." + for tensor_dim, space_dim in zip(shape, space_shape): + if space_dim is None: + continue + if tensor_dim != space_dim: + return False + return True + + +def prepend_dimensions(tensor: th.Tensor, dimensions_to_add: tuple) -> th.Tensor: + """ + Prepend dimensions to a tensor by repeating the tensor on the new dimensions. + E.g., if original shape is [3, 12, 12], and dimensions_to_add is (1, 2), then the output shape is [1, 2, 3, 12, 12]. + + Args: + tensor: The tensor to prepend dimensions to. + dimensions_to_add: tuple of dimensions to add. + + Returns: + The tensor with prepended dimensions. + """ + expected_shape = tuple(dimensions_to_add) + tensor.shape + repeat_dims = list(dimensions_to_add) + [1] * len(tensor.shape) + output = tensor.repeat(*repeat_dims) + assert output.shape == expected_shape, f"Expected shape {expected_shape}, got {tensor.shape}." + return output + + +def convert_potential_scalar_to_full_tensor(tensor_or_scalar: Union[int, float, th.Tensor], shape: Tuple[Union[int, None]], dtype: th.dtype) -> th.Tensor: + """ + Convert a potential scalar to a broadcastable tensor. + This is used to convert "low" and "high" arguments to TensorSpace to tensors. + + If shape has Nones (wildcards), these will be replaced with dimension 1 in the output tensor. + This will allow broadcastable operations when doing comparisons + + Args: + tensor_or_scalar: The potential scalar to convert to a tensor. + shape: The shape of the tensor to convert to. + dtype: The dtype of the tensor to convert to. + + Returns: + The converted tensor. + """ + assert isinstance(shape, tuple), f"Input shape {shape} is not a tuple." + assert_is_valid_space_shape_tuple(shape) + assert isinstance(dtype, th.dtype), f"Input dtype {dtype} is not a torch dtype." + + target_shape = tuple(1 if dim is None else dim for dim in shape) + if isinstance(tensor_or_scalar, (int, float)): + output_tensor = th.full(target_shape, tensor_or_scalar, dtype=dtype) + elif isinstance(tensor_or_scalar, th.Tensor): + # Make sure the tensor is broadcastable + assert tensor_or_scalar.shape == target_shape, f"Input tensor shape {tensor_or_scalar.shape} is not broadcastable to shape {shape}. None's should be 1's." + assert tensor_or_scalar.dtype == dtype, f"Input tensor dtype {tensor_or_scalar.dtype} is not expected dtype {dtype}." + output_tensor = tensor_or_scalar + else: + raise ValueError(f"Input {tensor_or_scalar} is not a scalar or tensor.") + return output_tensor + + +class TensorSpace: + """ + Base class for defining a space for pytorch tensors. Similar to gym.spaces, but for pytorch. + A space can be used to define the accepted/expected shape and dtype of a tensor. + + NOTE: tensors can have any number of preceding dimensions, but the last dimensions must match. + E.g. if space defines a 3D image of (channel, height, width), then a 4D tensor of + (batch, channel, height, width) is accepted, but a 4D tensor of (batch, height, width, channel) + is not accepted. + + Args: + shape: The expected shape of the tensor. This is tuple of integers or "None". A "None" indicates a wildcard + dimension, i.e. any integer is accepted. + dtype: The expected dtype of the tensor. This is strictly enforced*, i.e. even if the dtype was "castable" + to the expected dtype, it will not be accepted. + (*) Special case: float32 space can be casted to float16 in `contains` function to allow mixed precision training. + low: The lower bound of the tensor (inclusive, optional). Can be a scalar or torch tensor of the same shape as "shape" + high: The upper bound of the tensor (inclusive, optional). Can be a scalar or torch tensor of the same shape as "shape" + """ + + def __init__( + self, + shape: Tuple[Union[int, None]], + dtype: th.dtype = th.float, + low: Optional[Union[int, float, th.Tensor]] = None, + high: Optional[Union[int, float, th.Tensor]] = None, + ): + assert_is_valid_space_shape_tuple(shape) + assert isinstance(dtype, th.dtype), f"Input dtype {dtype} is not a torch dtype." + if low is not None: + low = convert_potential_scalar_to_full_tensor(low, shape, dtype) + if high is not None: + high = convert_potential_scalar_to_full_tensor(high, shape, dtype) + if low is not None and high is not None: + assert th.all(low <= high), f"Input low {low} is not <= high {high}." + + self._shape = shape + self._ndim = len(shape) + self._dtype = dtype + self._low = low + self._high = high + self.device = th.device("cpu") + self.to(self.device) + + def _check_and_move_device_if_necessary(self, tensor_or_space: Union[th.Tensor, "TensorSpace"]) -> None: + """ + Check this TensorSpace is in right device to do operations with the tensor. + If not, move this TensorSpace to the device of the tensor. + + Args: + tensor: The tensor to check for. + """ + if tensor_or_space.device != self.device: + self.to(tensor_or_space.device) + + @property + def shape(self) -> Tuple[Union[int, None]]: + return self._shape + + @property + def dtype(self) -> th.dtype: + return self._dtype + + @property + def low(self) -> Optional[th.Tensor]: + return self._low + + @property + def high(self) -> Optional[th.Tensor]: + return self._high + + def __eq__(self, other: object) -> bool: + assert isinstance(other, TensorSpace), f"Input {other} is not a TensorSpace." + if self.shape != other.shape: + return False + if self.dtype != other.dtype: + return False + + if isinstance(self.low, th.Tensor) and isinstance(other.low, th.Tensor): + if not th.all(self.low == other.low): + return False + elif self.low != other.low: + return False + + if isinstance(self.high, th.Tensor) and isinstance(other.high, th.Tensor): + if not th.all(self.high == other.high): + return False + elif self.high != other.high: + return False + + return True + + def to(self, device: th.device) -> "TensorSpace": + """ + Move the space to a device. + + Args: + device: The device to move the space to. + + Returns: + The space on the device. + """ + + if self._low is not None: + self._low = self._low.to(device) + if self._high is not None: + self._high = self._high.to(device) + self.device = device + return self + + def __repr__(self) -> str: + return f"(TensorSpace shape={self.shape}, dtype={self.dtype}, low={self.low}, high={self.high})" + + def contains(self, tensor: th.Tensor) -> bool: + """ + Check if a tensor is in the space. + + Args: + tensor: The tensor to check. + + Returns: + True if the tensor is in the space, False otherwise. + """ + if tensor.ndim < self._ndim: + return False + # Only check the trailing dimensions + x_shape = tensor.shape[-self._ndim :] + if not check_if_shape_is_in_space_shape(x_shape, self._shape): + return False + + if tensor.dtype != self.dtype: + # Special case for mixed-precision training: allow float32 to be casted to float16 + if not (self.dtype == th.float32 and tensor.dtype == th.float16): + return False + + self._check_and_move_device_if_necessary(tensor) + if self._low is not None and th.any(tensor < self._low): + return False + if self._high is not None and th.any(tensor > self._high): + return False + + return True + + def is_subset_of(self, other: "TensorSpace") -> bool: + """ + Check if this space is a subset of another space. + + A subset "sub" is subset of "super" if: + - Dtypes are the same + - Bounds of "sub" are within bounds of "super" + - Shapes are the same, except for wildcards + + NOTE: Space is subset of itself. + + Args: + other: The other space to check. + + Returns: + True if this space is a subset of the other space, False otherwise. + """ + assert isinstance(other, TensorSpace), f"Input {other} is not a TensorSpace." + + if self.dtype != other.dtype: + return False + + # Manually check ndim so we can return False + if self._ndim != other._ndim: + return False + + if not check_if_shape_is_in_space_shape(self.shape, other.shape): + return False + + self._check_and_move_device_if_necessary(other) + # If other (super) does not have bounds, we can skip the checks. + # But if other has bounds, then this (sub) must have bounds as well. + if other.low is not None: + if self.low is None or th.any(self.low < other.low): + return False + if other.high is not None: + if self.high is None or th.any(self.high > other.high): + return False + + return True + + def sample(self, dimensions_to_prepend: Optional[tuple] = None) -> th.Tensor: + """ + Sample a random uniform tensor from the space. + If dimensions_to_prepend is not None, then the tensor will have these dimensions prepended. + (e.g., for batch and time dimensions) + + Args: + dimensions_to_prepend: The dimensions to prepend to the tensor. + + Returns: + A random tensor from the space. + """ + if self.low is None or self.high is None: + raise ValueError("Cannot sample from an unbounded space.") + if None in self.shape: + raise ValueError("Cannot sample from a space with wildcard dimensions.") + tensor_shape: tuple[int] = self.shape # type: ignore # mypy thinks we can still have Nones + if self.dtype.is_floating_point: + random_tensor = th.rand(size=tensor_shape) * (self.high - self.low) + self.low + else: + high_plus_one = self.high + 1 + random_tensor = th.rand(size=tensor_shape) * (high_plus_one - self.low) + self.low + random_tensor = th.floor(random_tensor) + random_tensor = random_tensor.to(self.dtype) + if dimensions_to_prepend is not None: + assert isinstance(dimensions_to_prepend, tuple), f"Input dimensions_to_prepend {dimensions_to_prepend} is not a tuple." + random_tensor = prepend_dimensions(random_tensor, dimensions_to_prepend) + return random_tensor + + def get_preceding_dimensions(self, tensor: th.Tensor) -> tuple[int, ...]: + """ + Return the preceding dimensions of a tensor that are not part of the space. + Most commonly, these are the batch and time dimensions. + + E.g., if the space shape is (3, 4) and the tensor is (5, 3, 4), then the preceding dimensions are (5,). + + Args: + tensor: The tensor to check. + Returns: + The preceding dimensions of the tensor as a tuple. + """ + assert self.contains(tensor), f"Tensor {tensor} is not in the space {self}." + return tuple(tensor.shape[: -self._ndim]) + + +class TensorDictSpace: + """ + Akin to TensorDict but for TensorSpaces. + Holds a collection of TensorSpaces, each with a unique key. Operations are broadcast over keys. + """ + + def __init__(self, tensor_spaces: dict[str, TensorSpace]): + assert isinstance(tensor_spaces, dict), f"Input tensor_spaces {tensor_spaces} is not a dict." + for key, tensor_space in tensor_spaces.items(): + assert isinstance(key, str), f"Key {key} is not a string." + assert isinstance(tensor_space, TensorSpace), f"Value {tensor_space} is not a TensorSpace." + + self._tensor_spaces = tensor_spaces + + @property + def tensor_spaces(self) -> dict[str, TensorSpace]: + return self._tensor_spaces + + def __getitem__(self, key: str) -> TensorSpace: + return self.tensor_spaces[key] + + def __len__(self) -> int: + return len(self.tensor_spaces) + + def __iter__(self) -> Iterator[str]: + return self.keys() + + def keys(self) -> Iterator[str]: + return iter(self.tensor_spaces.keys()) + + def items(self) -> Iterator[Tuple[str, TensorSpace]]: + return iter(self.tensor_spaces.items()) + + def __repr__(self) -> str: + return f"(TensorDictSpace tensor_spaces={self.tensor_spaces})" + + def _check_tensordict_keys(self, tensor_dict: TensorDict, allow_key_subset: bool = False) -> bool: + """Check if input tensordict keys match the space keys, given the settings. Returns True if check passes, False otherwise.""" + tensor_dict_keys = set(tensor_dict.keys()) + tensor_spaces_keys = set(self.tensor_spaces.keys()) + if allow_key_subset: + # TensorDict should not have extra keys, but can miss some of the space keys + if not tensor_dict_keys.issubset(tensor_spaces_keys): + return False + else: + # All keys should match + if tensor_dict_keys != tensor_spaces_keys: + return False + return True + + def contains(self, tensor_dict: TensorDict, allow_key_subset: bool = False) -> bool: + """ + Check if a TensorDict is in the space. + TensorDict must have the matching keys (allow_key_subset = False) and each tensor must be in the corresponding TensorSpace. + + You can relax the requirement of all keys by setting allow_key_subset=False. + If True, then only a subset of keys is required, but all tensors in the TensorDict must be in the corresponding TensorSpace. + The TensorDict should also not have extra keys outside this space. + + Args: + tensor_dict: The tensor dict to check. + allow_key_subset: Whether to check that the keys match. If False, allow only a subset of keys. + + Returns: + True if the tensor dict is in the space, False otherwise. + """ + assert isinstance(tensor_dict, TensorDict), f"Input {tensor_dict} is not a tensordict." + + if not self._check_tensordict_keys(tensor_dict, allow_key_subset): + return False + + for key, tensor in tensor_dict.items(): + if not self.tensor_spaces[key].contains(tensor): + return False + + return True + + def is_subset_of(self, other: "TensorDictSpace") -> bool: + """ + Check if this TensorDictSpace is a subset of another TensorDictSpace. + See `TensorSpace.is_subset_of` for more details. + This function repeats the check for each key in the TensorDictSpace, and returns True if all checks pass. + + Args: + other: the assumed superspace. + + Returns: + True if this space is a subset of the other space, False otherwise. + """ + assert isinstance(other, TensorDictSpace), f"Input other {other} is not a TensorDictSpace." + + if set(self.tensor_spaces.keys()) != set(other.tensor_spaces.keys()): + return False + + for key, tensor_space in self.tensor_spaces.items(): + if not tensor_space.is_subset_of(other.tensor_spaces[key]): + return False + + return True + + def sample(self, dimensions_to_prepend: Optional[tuple] = None) -> TensorDict: + """ + Sample a random tensor dict from the space. + Raises if any of the spaces are unbounded or have wildcard dimensions. + + If dimensions_to_prepend is not None, then the tensor will have these dimensions prepended. + (e.g., for batch and time dimensions). + + Args: + dimensions_to_prepend: The dimensions to prepend to the tensor. + + Returns: + TensorDict - a random tensor dict from the space. + """ + assert dimensions_to_prepend is None or isinstance(dimensions_to_prepend, tuple), f"Input dimensions_to_prepend {dimensions_to_prepend} is not a tuple or None." + if dimensions_to_prepend is None: + dimensions_to_prepend = tuple() + return TensorDict({key: tensor_space.sample(dimensions_to_prepend) for key, tensor_space in self.tensor_spaces.items()}, batch_size=dimensions_to_prepend) + + def get_preceding_dimensions(self, tensor_dict: TensorDict, allow_key_subset: bool = False) -> Optional[tuple[int, ...]]: + """ + Return the preceding dimensions of the tensors. All tensors must have the same preceding dimensions. + For TensorDicts, this corresponds to the "batch_dim" argument. + + If preceding dimensions are not the same, raises ValueError. + + You can relax the requirement of all keys by setting allow_key_subset=False. + If True, then only a subset of keys is required, but all tensors in the TensorDict must be in the corresponding TensorSpace. + The TensorDict should also not have extra keys outside this space. + + Args: + tensor_dict: The tensor dict to check. + allow_key_subset: Whether to check that the keys match. If False, allow only a subset of keys. + + Returns: + The preceding dimensions of the tensors as a tuple. + """ + assert isinstance(tensor_dict, TensorDict), f"Input tensor_dict {tensor_dict} is not a dict." + + if not self._check_tensordict_keys(tensor_dict, allow_key_subset): + raise ValueError(f"TensorDict {tensor_dict} does not have the same keys as TensorDictSpace {self}.") + + preceding_dimension = None + for key, tensor in tensor_dict.items(): + tensor_space = self.tensor_spaces[key] + if not tensor_space.contains(tensor): + raise ValueError(f"TensorDict {tensor_dict} does not have the same keys as TensorDictSpace {self}.") + tensor_preceding_dimensions = tensor_space.get_preceding_dimensions(tensor) + if preceding_dimension is None: + preceding_dimension = tensor_preceding_dimensions + + if tensor_preceding_dimensions != preceding_dimension: + raise ValueError(f"TensorDict {tensor_dict} has {tensor_preceding_dimensions} for preceding dimensions, expected {preceding_dimension}.") + + return preceding_dimension diff --git a/wham/models/wham_token_model/gpt_token_transformer_predictor.py b/wham/models/wham_token_model/gpt_token_transformer_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..c865cf876e709f6ee8fb33dc91605ce5fab8dadf --- /dev/null +++ b/wham/models/wham_token_model/gpt_token_transformer_predictor.py @@ -0,0 +1,355 @@ +import torch as th +from tensordict import TensorDict + +from wham.models.nn.nanoGPT import GPT, GPTConfig +from wham.models.wham_base.predictor import PredictorBase +from wham.models.wham_base.tensor_spaces import TensorSpace + +# These are from the Chincilla paper +# https://arxiv.org/abs/2203.15556 +GPT_MODEL_SIZES = { + "chinchilla-251M": {"n_layer": 16, "n_head": 16, "n_embd": 1024}, + "chinchilla-1018M": {"n_layer": 23, "n_head": 14, "n_embd": 1792}, + "wham-4kvocab-1.6b": {"n_layer": 24, "n_head": 18, "n_embd": 2304}, # Bad name, vocab size is actually 16K + "wham-3b": {"n_layer": 30, "n_head": 24, "n_embd": 3072}, +} + + +def interleave_seq_token_tensors(tensors): + """ + Interleaves tokens from different sequences. + For example, if we have tensors (states, actions), we want to interleave them + into a single sequence of tokens in style [state1, action1, state2, action2, ...] + + Inputs: + tensors: list of torch tensors of shape (batch_size, seq_len, num_tokens) + `num_tokens` can vary between tensors. + Outputs: + interleaved: tensor of shape (batch_size, seq_len * (num_tokens1 + num_tokens2 + ...)) + """ + assert all(tensor.ndim == 3 for tensor in tensors), "All tensors must be 3D" + interleaved = th.cat(tensors, dim=-1) + interleaved = interleaved.reshape(interleaved.shape[0], -1) + # Continuity is required for efficient memory access (and nn operations) + return interleaved.contiguous() + + +def deinterleave_seq_token_tensors(interleaved, num_tokens_per_tensor): + """ + Inverse of interleave_seq_token_tensors. + Takes in interleaved tensor of tokens (batch_size, seq_len * (num_tokens1 + num_tokens2 + ...)), + and returns a list of tensors of shape (batch_size, seq_len, num_tokens), + where `num_tokens` is specified by num_tokens_per_tensor (a list of integers). + + Inputs: + interleaved: tensor of shape (batch_size, seq_len * (num_tokens_per_tensor[0] + num_tokens_per_tensor[1] + ...)) + num_tokens_per_tensor: list of integers specifying the number of tokens per tensor + Outputs: + tensors: list of torch tensors of shape (batch_size, seq_len, num_tokens) + """ + assert interleaved.ndim == 2, "Interleaved tensor must be 2D" + num_tokens_per_step = sum(num_tokens_per_tensor) + num_tokens = interleaved.shape[-1] + + assert num_tokens % num_tokens_per_step == 0, "Interleaved tensor must be divisible by num_tokens_per_step" + seq_len = num_tokens // num_tokens_per_step + + matrix_interleaved = interleaved.reshape(-1, seq_len, num_tokens_per_step) + tensors = [] + start = 0 + for num_tokens in num_tokens_per_tensor: + tensors.append(matrix_interleaved[:, :, start : start + num_tokens]) + start += num_tokens + return tensors + + +def interleave_seq_token_embedding_tensors(tensors): + """ + Same as interleave_seq_token_tensors, but 4D tensors (instead of tokens, we have embedding vectors). + + Interleaves token embedding tensors (batch, seq_len, ?, embedding_dim) from different tensors. + Dimension ? is the number of tokens each item in different tensors have. + + This is same as interleave_seq_token_tensors, but for token embedding tensors (additional dimension). + + For example, if we have tensors (states, actions) in following shapes: + states: (batch_size, seq_len, tokens_per_state, embedding_dim), and + actions: (batch_size, seq_len, tokens_per_action, embedding_dim), + where each item (last two dimensions) represents as single state or action in multiple tokens. + This function interleaves them into a single sequence of tokens in order + [state1, action1, state2, action2, ...], + with the shape + (batch_size, seq_len * (tokens_per_state + tokens_per_action), embedding_dim) + + Inputs: + tensors: list of torch tensors of shape (batch_size, seq_len, num_tokens, embedding_dim) + `num_tokens` can vary between tensors. + Outputs: + interleaved: tensor of shape (batch_size, seq_len * (num_tokens1 + num_tokens2 + ...), embedding_dim) + """ + assert all(tensor.ndim == 4 for tensor in tensors), "All tensors must be 4D" + interleaved = th.cat(tensors, dim=2) + embedding_dim = interleaved.shape[-1] + interleaved = interleaved.reshape(interleaved.shape[0], -1, embedding_dim) + # Continuity is required for efficient memory access (and nn operations) + return interleaved.contiguous() + + +def create_nano_gpt_model(model_size, max_context_length_tokens, vocab_size=1, version=1, bias=True): + assert model_size in GPT_MODEL_SIZES, "Invalid model size" + gpt_config = GPTConfig() + gpt_config.vocab_size = vocab_size + gpt_model_size_conf = GPT_MODEL_SIZES[model_size] + gpt_config.n_layer = gpt_model_size_conf["n_layer"] + gpt_config.n_head = gpt_model_size_conf["n_head"] + gpt_config.n_embd = gpt_model_size_conf["n_embd"] + gpt_config.block_size = max_context_length_tokens + gpt_config.version = version + gpt_config.bias = bias + gpt_model = GPT(gpt_config) + + return gpt_model, gpt_config + + +class GPTTokenPredictor(PredictorBase): + """ + Modality predictor that works on token basis and predicts next tokens. + - Uses positional encoding per token + - Uses fixed ordering of modalities + - Autoregressive prediction + - Tokens for each modality are interleaved into a single sequence + - Each modality has its own set of tokens (i.e., no overlapping token indeces between modalities) + """ + + __DEBUG_CREATION_KWARGS__ = { + "model_spec": { + "model_size": "debug_small_width", + "seq_len": 4, + } + } + + _acceptable_context_spaces = (TensorSpace((None,), dtype=th.long),) + _acceptable_condition_spaces = tuple() + + def __init__(self, context_space, condition_space, model_spec): + super().__init__(context_space, condition_space) + self.save_hyperparameters() + + self.model_size = model_spec["model_size"] + self.seq_len = model_spec["seq_len"] + + # Number of tokens per modality, as a dict and as a list (in order for (de)interleaving)) + self.tokens_per_modality = {} + self.tokens_per_modality_list = [] + + # Each modality will have its own set of tokens. + # But also respect the individual separation of token ranges per modality. + # Modality name -> integer, telling how much we need to offset tokens for this modality + self.vocab_offset_per_modality = {} + # modality name -> Tuple[int, int], telling the range of tokens for this modality, for each token individually + self.vocab_range_per_modality = {} + # List of modality names, in the order they are interleaved into the sequence. + self.modality_order = [] + self.total_vocab_size = 0 + for name, space in self.context_space.items(): + assert space.high is not None and space.low is not None, "High and low must be specified for all context spaces (vocab size per modality)" + self.tokens_per_modality[name] = space.shape[0] + self.tokens_per_modality_list.append(space.shape[0]) + + low_tensor = space.low + high_tensor = space.high + + min_token = low_tensor.min().item() + assert min_token == 0, f"Lowest token of space {name} is {min_token}, but must be 0. This is for clarity and avoiding unused tokens" + max_token = high_tensor.max().item() + + self.vocab_offset_per_modality[name] = self.total_vocab_size + self.vocab_range_per_modality[name] = tuple( + (int(low.item()) + self.total_vocab_size, int(high.item()) + self.total_vocab_size) for low, high in zip(low_tensor, high_tensor) + ) + # +1 because high is inclusive, and we ensured that low is 0 + self.total_vocab_size += max_token + 1 + + self.modality_order.append(name) + + # Adjust the tokens allowed when generating the very first image token + # All we are doing is disallowing the 0th token + # NOTE: THIS IS *NOT* required at all + # However, it provides a better experience when using the 200M model trained on 128x128 images + # This is a bit fiddly since tuples don't support item assignment + list_version_of_tuple = list(self.vocab_range_per_modality["images"]) + list_version_of_tuple[0] = (list_version_of_tuple[0][0] + 1, list_version_of_tuple[0][1]) + self.vocab_range_per_modality["images"] = tuple(list_version_of_tuple) + # End of adjustment for the first image token + + self.total_tokens_per_step = sum(self.tokens_per_modality.values()) + self.seq_len_in_tokens = self.seq_len * self.total_tokens_per_step + self.seq_len_in_tokens_for_inference = (self.seq_len - 1) * self.total_tokens_per_step + + gpt_version = model_spec.get("nanogpt_version", 1) + gpt_bias = model_spec.get("bias", True) + print(f"Creating NanoGPT model with version {gpt_version}") + self.gpt_model, self.gpt_config = create_nano_gpt_model(self.model_size, vocab_size=self.total_vocab_size, max_context_length_tokens=self.seq_len_in_tokens, version=gpt_version, bias=gpt_bias) + + def parameters(self): + return self.gpt_model.parameters() + + def _create_gpt(self, model_size, vocab_size, max_context_length_tokens): + assert model_size in GPT_MODEL_SIZES, "Invalid model size" + gpt_config = GPTConfig() + gpt_config.vocab_size = vocab_size + gpt_model_size_conf = GPT_MODEL_SIZES[model_size] + gpt_config.n_layer = gpt_model_size_conf["n_layer"] + gpt_config.n_head = gpt_model_size_conf["n_head"] + gpt_config.n_embd = gpt_model_size_conf["n_embd"] + gpt_config.block_size = max_context_length_tokens + gpt_model = GPT(gpt_config) + + return gpt_model, gpt_config + + def _check_not_too_long_context_length(self, tokens, num_tokens_to_be_generated): + """ + Check that the context length is not too long for the amount we are trying to generate + """ + if (tokens.shape[1] + num_tokens_to_be_generated) > self.seq_len_in_tokens: + raise ValueError( + f"Trying to generate too many tokens given the context. Context {tokens.shape[1]} should be less than {self.seq_len_in_tokens} - {num_tokens_to_be_generated}" + ) + + def _interleave_and_offset_modalities(self, modalities): + """ + Interleave and offset tokens from different modalities into a single sequence. + Offset tokens of each modality so that different modalities do not overlap. + + Assumes modalities is already checked to be valid input for the model. + """ + modality_list = [modalities[name] + self.vocab_offset_per_modality[name] for name in self.modality_order] + interleaved_tokens = interleave_seq_token_tensors(modality_list) + return interleaved_tokens + + def _deinterleave_and_offset_tokens(self, interleaved_tokens): + """ + Inverse of _interleave_and_offset_modalities + """ + modality_list = deinterleave_seq_token_tensors(interleaved_tokens, self.tokens_per_modality) + modality_list = {name: modality_list[i] - self.vocab_offset_per_modality[name] for i, name in enumerate(self.modality_order)} + return modality_list + + def predict_n_tokens(self, tokens, n_tokens, valid_token_ranges, deterministic=False, temperature=1.0, top_k=None, top_p=None, min_tokens_to_keep=1): + """ + Given a sequence of tokens, predict the next action. + Returns new list of tokens with the predicted action appended. + + Inputs: + tokens: torch tensor (batch_size, seq_len) + n_tokens: int, number of tokens to predict + valid_token_ranges: Tuple[int, int] of valid vocab indices to predict from for each token (inclusive on both sides) + **kwargs: kwargs for gpt_model.optimized_generate + """ + self._check_not_too_long_context_length(tokens, n_tokens) + assert n_tokens == len( + valid_token_ranges + ), f"Must have a valid token range for each token to be generated. Expected {n_tokens}, got valid_token_ranges of length {len(valid_token_ranges)}" + new_tokens = self.gpt_model.optimized_generate( + tokens, + n_tokens, + valid_token_ranges=valid_token_ranges, + raise_cropping=True, + deterministic=deterministic, + temperature=temperature, + top_k=top_k, + top_p=top_p, + min_tokens_to_keep=min_tokens_to_keep, + ) + return new_tokens + + def cross_entropy_prediction_loss_on_tokens(self, token_seq, loss_mask): + """ + Given a sequence of tokens, try to predict next tokens on every timestep given all previous timesteps. + Returns average loss. + + Inputs: + token_seq: torch tensor (batch_size, token_seq_len) + dtype: th.long (0 <= token < vocab_size) + mask: torch tensor (batch_size, token_seq_len). 1 if timestep is valid, 0 if timestep is padding + Outputs: + losses: loss per timestep (batch_size, token_seq_len), where first timestep loss is 0 (padded) + """ + inputs = token_seq[:, :-1].contiguous() + targets = token_seq[:, 1:].contiguous() + loss_mask = loss_mask[:, 1:].contiguous() + _, losses = self.gpt_model(inputs, targets=targets, loss_mask=loss_mask, loss_reduction="none") + # Pad from the left with zeros (there is no valid target for the first step) + losses = th.cat([th.zeros_like(losses[:, :1]), losses], dim=1) + return losses + + def cross_entropy_prediction_loss(self, modalities, loss_mask): + """ + Given a TensorDict of sequence of different modalities as tokens, interleave them into a single sequence + and try to predict next tokens on every timestep given all previous timesteps. + Returns average loss. + + Inputs: + modalities: TensorDict of modality name -> (batch_size, seq_len, tokens) + loss_mask: torch tensor (batch_size, seq_len). 1 if timestep is valid, 0 if timestep is padding + Outputs: + modality_losses: dictionary of losses per modality (batch_size, seq_len), where first timestep loss is 0 (padded) + n_valid_tokens: number of valid tokens in the loss mask + """ + self.assert_check_context_tensordict_is_valid(modalities) + interleaved_tokens = self._interleave_and_offset_modalities(modalities) + # Mask is just repeated the same number of times as there are tokens per timestep + loss_mask = loss_mask.repeat_interleave(dim=1, repeats=self.total_tokens_per_step) + losses = self.cross_entropy_prediction_loss_on_tokens(interleaved_tokens, loss_mask) + + # Split loss into different modalities + split_losses = deinterleave_seq_token_tensors(losses, self.tokens_per_modality_list) + + modality_losses = {name: loss for name, loss in zip(self.modality_order, split_losses)} + + # Compute average loss. + # To match with the previous numbers, where whole loss was summed over all tokens, + # we divide by the number of all valid tokens, not just the number of timesteps. + num_valid_tokens = loss_mask.sum() + return modality_losses, num_valid_tokens + + def predict_next_step(self, modalities, modalities_to_predict=None, **kwargs): + """ + Given a TensorDict of sequence of different modalities as tokens, predict the tokens for the next step. + + Inputs: + modalities: TensorDict of modality name -> (batch_size, seq_len, tokens) + modalities_to_predict: list of modalities to predict. If None, predict all modalities + NOTE: modalities_to_predict must be a subset of self.modality_order and in the same order. + e.g., if model was trained to predict steps in order [image, action], you can not predict + "action" first, as the model requires the image tokens first. + **kwargs: kwargs for gpt_model.optimized_generate + Outputs: + predicted_modalities: TensorDict of predicted modalities + all_tokens: tensor (batch_size, seq_len, tokens) of all tokens, including predicted ones + """ + self.assert_check_context_tensordict_is_valid(modalities) + # We have to manually avoid cutting down on context, as otherwise inference would fail (first token in context + # has to _always_ be first token of an image). + if modalities.shape[1] == self.seq_len: + modalities = modalities[:, 1:] + + all_tokens = self._interleave_and_offset_modalities(modalities) + predicted_tokens = dict() + + modalities_to_predict = modalities_to_predict or self.modality_order + for desired_modality, modality_name in zip(modalities_to_predict, self.modality_order): + assert ( + desired_modality == modality_name + ), f"Modalities to predict {modalities_to_predict} was in wrong order. Must follow the ordering of {self.modality_order}" + + for modality_name in self.modality_order: + tokens_to_predict = self.tokens_per_modality[modality_name] + all_tokens = self.predict_n_tokens(all_tokens, tokens_to_predict, self.vocab_range_per_modality[modality_name], **kwargs) + predicted_tokens[modality_name] = all_tokens[:, -tokens_to_predict:] - self.vocab_offset_per_modality[modality_name] + # Add time dimension + predicted_tokens[modality_name] = predicted_tokens[modality_name].unsqueeze(1) + + batch_dimension = self.context_space.get_preceding_dimensions(modalities)[0] + predicted_modalities = TensorDict(predicted_tokens, batch_size=(batch_dimension, 1)) + return predicted_modalities, tokens_to_predict diff --git a/wham/models/wham_token_model/token_action_encoder.py b/wham/models/wham_token_model/token_action_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..561c174229c9a0c800d6f52e750fb44b79c47eb8 --- /dev/null +++ b/wham/models/wham_token_model/token_action_encoder.py @@ -0,0 +1,136 @@ +import torch as th + +from wham.models.wham_base.tensor_spaces import TensorSpace +from wham.models.wham_base.encoder_decoder import EncoderDecoderBase + +# Each binary button will get on/off token associated with that button. +# This is to replicate original world model experiments. +# Potentially, could also try having only one on and one off token which +# is used by all buttons. +VOCAB_SIZE_FOR_BUTTON = 2 + +MAX_BUTTONS = 12 +POS_ACTIONS = 4 +POS_CLASSES = 11 + + +def get_valid_token_range_for_action_idx(action_idx, n_bins, token_offset=0): + """ + Given index of action token, return the range of valid token indices + for that action index, inclusive on both sides + + Inputs: + action_idx: index of action token + n_bins: number of bins used for stick discretization + token_offset: offset to add to token indices + Outputs: + valid_token_range: (min_token_idx, max_token_idx) + """ + if action_idx < MAX_BUTTONS: + # Button + min_token_idx = action_idx * VOCAB_SIZE_FOR_BUTTON + max_token_idx = min_token_idx + VOCAB_SIZE_FOR_BUTTON - 1 + else: + # Stick + min_token_idx = MAX_BUTTONS * VOCAB_SIZE_FOR_BUTTON + (action_idx - MAX_BUTTONS) * n_bins + max_token_idx = min_token_idx + n_bins - 1 + return min_token_idx + token_offset, max_token_idx + token_offset + + +def tokenize_actions(action_seq_batch, n_bins, token_offset=0): + """ + Tokenize BE actions into a sequence of tokens: + - Buttons are mapped to on/off tokens. Each button has its unique on/off tokens. + - Stick actions (which should be discrete) are mapped to unique tokens per stick. + + Inputs: + action_seq_batch: torch tensor (batch_size, seq_len, MAX_BUTTONS + POS_ACTIONS) + n_bins: number of bins used for stick discretization + token_offset: offset to add to token indices to avoid overlap with state tokens + Outputs: + action_seq_batch_discrete: (batch_size, seq_len, MAX_BUTTONS + POS_ACTIONS) + """ + # Make sure we get what we expect + assert action_seq_batch.shape[-1] == MAX_BUTTONS + POS_ACTIONS + action_token_seq_batch = th.zeros_like(action_seq_batch).long() + + # Buttons + total_token_offset = token_offset + for button_i in range(MAX_BUTTONS): + # Unique on/off token for every button + action_token_seq_batch[:, :, button_i] = (action_seq_batch[:, :, button_i] + button_i * VOCAB_SIZE_FOR_BUTTON + total_token_offset).long() + + total_token_offset += MAX_BUTTONS * VOCAB_SIZE_FOR_BUTTON + for action_index in range(MAX_BUTTONS, MAX_BUTTONS + POS_ACTIONS): + stick_index = action_index - MAX_BUTTONS + action_token_seq_batch[:, :, action_index] = (action_seq_batch[:, :, action_index] + stick_index * n_bins + total_token_offset).long() + + return action_token_seq_batch + + +def detokenize_actions(action_token_seq_batch, n_bins, token_offset=0): + """ + Reverse of tokenize_actions. See tokenize_actions for details. + Note that this returns discretized actions for sticks, which follow the discretization scheme of + rest of this repository (see e.g., data.parser.action_plugins) + """ + action_seq_batch = th.zeros_like(action_token_seq_batch).float() + + # Buttons + total_token_offset = token_offset + for button_i in range(MAX_BUTTONS): + action_seq_batch[:, :, button_i] = (action_token_seq_batch[:, :, button_i] - button_i * VOCAB_SIZE_FOR_BUTTON - total_token_offset).float() + + total_token_offset += MAX_BUTTONS * VOCAB_SIZE_FOR_BUTTON + # Assume rest are continuous actions + for action_index in range(MAX_BUTTONS, MAX_BUTTONS + POS_ACTIONS): + stick_index = action_index - MAX_BUTTONS + action_bin = action_token_seq_batch[:, :, action_index] - stick_index * n_bins - total_token_offset + action_seq_batch[:, :, action_index] = action_bin + return action_seq_batch + + +def get_action_vocab_size(bins_for_sticks): + """Return vocab size required by buttons""" + # Each button has 2 tokens (on/off), each stick has n_bins_for_sticks tokens (unique to every button/stick) + return MAX_BUTTONS * VOCAB_SIZE_FOR_BUTTON + POS_ACTIONS * bins_for_sticks + + +class ActionTokenEncoder(EncoderDecoderBase): + """ + Encoder for turning BE actions into sequence of tokens + """ + + __DEBUG_CREATION_KWARGS__ = dict() + + def __init__(self): + super().__init__() + self.n_bins_for_sticks = POS_CLASSES + self.vocab_size = get_action_vocab_size(self.n_bins_for_sticks) + + action_dim = MAX_BUTTONS + POS_ACTIONS + + # Original actions have buttons {0, 1} and then discretized positions [0, POS_CLASSES - 1] + world_space_lows = th.tensor([0] * MAX_BUTTONS + [0] * POS_ACTIONS, dtype=th.float) + world_space_highs = th.tensor([1] * MAX_BUTTONS + [POS_CLASSES - 1] * POS_ACTIONS, dtype=th.float) + self.world_space = TensorSpace((action_dim,), dtype=th.float, low=world_space_lows, high=world_space_highs) + + # In encoder space, each button has its own on/off token, and each stick has n_bins_for_sticks tokens + self._action_token_ranges = [get_valid_token_range_for_action_idx(i, self.n_bins_for_sticks) for i in range(action_dim)] + encoder_space_lows = th.tensor([r[0] for r in self._action_token_ranges], dtype=th.long) + encoder_space_highs = th.tensor([r[1] for r in self._action_token_ranges], dtype=th.long) + self.encoder_space = TensorSpace((action_dim,), dtype=th.long, low=encoder_space_lows, high=encoder_space_highs) + + def _encode(self, world_space_tensor): + """ + Encode BE actions into tokens + """ + assert world_space_tensor.ndim == 3, "ActionTokenEncoder only supports (batch, seq_len, action_dim) tensors" + return tokenize_actions(world_space_tensor, self.n_bins_for_sticks) + + def _decode(self, encoder_space_tensor): + """ + Decode tokens into BE actions + """ + assert encoder_space_tensor.ndim == 3, "ActionTokenEncoder only supports (batch, seq_len, action_dim) tensors" + return detokenize_actions(encoder_space_tensor, self.n_bins_for_sticks) diff --git a/wham/models/wham_token_model/wham_token.py b/wham/models/wham_token_model/wham_token.py new file mode 100644 index 0000000000000000000000000000000000000000..b2fbe3d814917ce3da1dc4b80bda2abe4c51321e --- /dev/null +++ b/wham/models/wham_token_model/wham_token.py @@ -0,0 +1,68 @@ +import torch +from wham.models.wham_base.encode_predict_decode_base import ( + EncodePredictDecodeModule, + create_encoder_modules_from_args, +) +from wham.models.wham_token_model.gpt_token_transformer_predictor import GPTTokenPredictor +from wham.models.vqgan.taming_vq_model import TamingVQModel +from wham.models.wham_token_model.token_action_encoder import ActionTokenEncoder +from wham.models.pl.pl_base_model import BaseTrainingModel + +LOSS_MASK_KEY = "loss_mask" + + +def class_name_to_model(class_name): + if class_name == "GPTTokenPredictor": + return GPTTokenPredictor + if class_name == "TamingVQModel": + return TamingVQModel + if class_name == "ActionTokenEncoder": + return ActionTokenEncoder + raise NotImplementedError(f"Model type {class_name} not implemented.") + + +class WHAMTokenModule(BaseTrainingModel, EncodePredictDecodeModule): + """A model that functions on a token level (e.g., combines all states and actions into one long sequence)""" + + def __init__(self, predictor_args, context_encoder_args, variant): + self.save_hyperparameters() + self.variant = variant + + context_encoders = create_encoder_modules_from_args(context_encoder_args) + # Freeze the context encoders + for context_encoder_param in context_encoders.parameters(): + context_encoder_param.requires_grad = False + + # Determine whether to use only the BC loss + self.bc_loss_only = variant["model_spec"].get("bc_loss_only", False) + + super().__init__(predictor_args=predictor_args, context_encoders=context_encoders) + + def predict_next_step(self, world_space_context, **kwargs): + """ + Predict the next step in the world space context. + + Args: + world_space_context (TensorDict): A TensorDict containing the world space context. + **kwargs: passed to predictor "predict_next_step" + Returns: + TensorDict: A TensorDict containing the predicted next step (batch, 1, ...) + """ + context = self.encode_context(world_space_context) + + # If we have tokens for an image, lets override their tokens + # Code is not great, but it gets the job done... + tokens = kwargs.get("tokens", None) + batch_size = context["images"].shape[0] + if tokens is not None: + for batch_idx in range(batch_size): + for timestep in range(context["images"][batch_idx].shape[0]): + if tokens[batch_idx][timestep] is not None: + tensored_tokens = torch.tensor(tokens[batch_idx][timestep], device=context["images"].device) + context["images"][batch_idx][timestep] = tensored_tokens + if "tokens" in kwargs: + del kwargs["tokens"] # We've used this, so remove it + predicted_next_step, _ = self.predictor.predict_next_step(context, **kwargs) + image_tokens = predicted_next_step["images"] + decoded_next_step = self.decode_context(predicted_next_step) + return decoded_next_step, image_tokens diff --git a/wham/utils.py b/wham/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4b7a740f04121e434d5e0c0ca95e992ce66cd4e6 --- /dev/null +++ b/wham/utils.py @@ -0,0 +1,25 @@ +import cv2 + +from wham.models.wham_token_model.wham_token import WHAMTokenModule + +# Hardcoded number (the FPS model was trained for) +DREAMING_FPS = 10 + +# Watermark configs +WATERMARK_TEXT = "Generated by WHAM" +PROGRAM_NAME = "Generated by WHAM (World Human-Action Model)" +WATERMARK_FONT = cv2.FONT_HERSHEY_SIMPLEX +WATERMARK_FONT_COLOR = (255, 255, 255) +WATERMARK_FONT_SCALE = 0.4 +WATERMARK_FONT_THICKNESS = 1 + +POS_BINS_BOUNDARIES = [-1.05, -0.95, -0.75, -0.5, -0.25, -0.05, 0.05, 0.25, 0.5, 0.75, 0.95, 1.05] + +POS_BINS_MIDDLE = [-1, -0.85, -0.625, -0.375, -0.15, 0, 0.15, 0.375, 0.625, 0.85, 1] + +def load_model_from_checkpoint(load_checkpoint_path): + # Need to do this trickery as we changed names of the module + import wham + import sys + sys.modules["humanmodelling"] = wham + return WHAMTokenModule.load_from_checkpoint(load_checkpoint_path, strict=False, map_location="cpu") \ No newline at end of file diff --git a/wham_demonstrator/Examples/0_facility_map_sample/actions.json b/wham_demonstrator/Examples/0_facility_map_sample/actions.json new file mode 100644 index 0000000000000000000000000000000000000000..65bd791e05764cbe52a56702664771bf6a624871 --- /dev/null +++ b/wham_demonstrator/Examples/0_facility_map_sample/actions.json @@ -0,0 +1 @@ +{"image0.png": [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.15000000596046448, 0.0, 0.0], "image1.png": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.625, 0.625, 0.0, 0.0], "image2.png": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.375, 0.8500000238418579, 0.0, 0.0], "image3.png": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.375, 0.8500000238418579, 0.0, 0.0], "image4.png": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.375, 0.8500000238418579, 0.0, 0.0], "image5.png": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.375, 0.8500000238418579, 0.0, 0.0], "image6.png": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.375, 0.8500000238418579, 1.0, 0.15000000596046448], "image7.png": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.15000000596046448, 1.0, 1.0, 0.15000000596046448], "image8.png": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0], "image9.png": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.15000000596046448, 1.0, 1.0, 0.15000000596046448]} \ No newline at end of file diff --git a/wham_demonstrator/Examples/0_facility_map_sample/image0.png b/wham_demonstrator/Examples/0_facility_map_sample/image0.png new file mode 100644 index 0000000000000000000000000000000000000000..802f10d7d7e3d8fee0bd231b02378e4001b2cbfa Binary files /dev/null and b/wham_demonstrator/Examples/0_facility_map_sample/image0.png differ diff --git a/wham_demonstrator/Examples/0_facility_map_sample/image1.png b/wham_demonstrator/Examples/0_facility_map_sample/image1.png new file mode 100644 index 0000000000000000000000000000000000000000..972e07db3900f4631abf7a11bcaed5aff6e69536 Binary files /dev/null and b/wham_demonstrator/Examples/0_facility_map_sample/image1.png differ diff --git a/wham_demonstrator/Examples/0_facility_map_sample/image2.png b/wham_demonstrator/Examples/0_facility_map_sample/image2.png new file mode 100644 index 0000000000000000000000000000000000000000..b474b43f956858a07e7b94c08743f0359b895fd8 Binary files /dev/null and b/wham_demonstrator/Examples/0_facility_map_sample/image2.png differ diff --git a/wham_demonstrator/Examples/0_facility_map_sample/image3.png b/wham_demonstrator/Examples/0_facility_map_sample/image3.png new file mode 100644 index 0000000000000000000000000000000000000000..36b0cf4eae27b41e4c85d02798da4491d309dc6a Binary files /dev/null and b/wham_demonstrator/Examples/0_facility_map_sample/image3.png differ diff --git a/wham_demonstrator/Examples/0_facility_map_sample/image4.png b/wham_demonstrator/Examples/0_facility_map_sample/image4.png new file mode 100644 index 0000000000000000000000000000000000000000..03fe5717ecb2e173df06410e7e2f6b7143898a83 Binary files /dev/null and b/wham_demonstrator/Examples/0_facility_map_sample/image4.png differ diff --git a/wham_demonstrator/Examples/0_facility_map_sample/image5.png b/wham_demonstrator/Examples/0_facility_map_sample/image5.png new file mode 100644 index 0000000000000000000000000000000000000000..85e1d85ec3af8c2cff0d317f322d08ee573b306f Binary files /dev/null and b/wham_demonstrator/Examples/0_facility_map_sample/image5.png differ diff --git a/wham_demonstrator/Examples/0_facility_map_sample/image6.png b/wham_demonstrator/Examples/0_facility_map_sample/image6.png new file mode 100644 index 0000000000000000000000000000000000000000..1f5ec7e8b36ab438f5664252d615b7a689a01425 Binary files /dev/null and b/wham_demonstrator/Examples/0_facility_map_sample/image6.png differ diff --git a/wham_demonstrator/Examples/0_facility_map_sample/image7.png b/wham_demonstrator/Examples/0_facility_map_sample/image7.png new file mode 100644 index 0000000000000000000000000000000000000000..9edbad1bb7e9616c474d441102146fdd26704456 Binary files /dev/null and b/wham_demonstrator/Examples/0_facility_map_sample/image7.png differ diff --git a/wham_demonstrator/Examples/0_facility_map_sample/image8.png b/wham_demonstrator/Examples/0_facility_map_sample/image8.png new file mode 100644 index 0000000000000000000000000000000000000000..12ef94ff7d517c8434c2afd0e426f502edbe8700 Binary files /dev/null and b/wham_demonstrator/Examples/0_facility_map_sample/image8.png differ diff --git a/wham_demonstrator/Examples/0_facility_map_sample/image9.png b/wham_demonstrator/Examples/0_facility_map_sample/image9.png new file mode 100644 index 0000000000000000000000000000000000000000..324a955152ffbd617376f0d61439f53aba690267 Binary files /dev/null and b/wham_demonstrator/Examples/0_facility_map_sample/image9.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-junkyard/actions.json b/wham_demonstrator/Examples/0_wham-sample-junkyard/actions.json new file mode 100644 index 0000000000000000000000000000000000000000..397823e96c13c35220db41747d163f78752df8c0 --- /dev/null +++ b/wham_demonstrator/Examples/0_wham-sample-junkyard/actions.json @@ -0,0 +1 @@ +{"image0.png": [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.625, 0.8500000238418579, 0.0, 0.0], "image1.png": [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.625, 0.625, 0.0, 0.0], "image2.png": [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "image3.png": [0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.15000000596046448, 0.0, 0.0], "image4.png": [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.8500000238418579, 0.15000000596046448, 0.0, 0.0], "image5.png": [0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.8500000238418579, 0.375, 0.0, 0.0], "image6.png": [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.625, 0.8500000238418579, 0.0, 0.0], "image7.png": [0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.625, 0.8500000238418579, 0.0, 0.0], "image8.png": [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.625, 0.625, 0.0, 0.0], "image9.png": [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.8500000238418579, 0.625, 0.0, 0.0]} \ No newline at end of file diff --git a/wham_demonstrator/Examples/0_wham-sample-junkyard/image0.png b/wham_demonstrator/Examples/0_wham-sample-junkyard/image0.png new file mode 100644 index 0000000000000000000000000000000000000000..f62d7f97ac36efbf06cc19f62bc99f99bb853d07 Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-junkyard/image0.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-junkyard/image1.png b/wham_demonstrator/Examples/0_wham-sample-junkyard/image1.png new file mode 100644 index 0000000000000000000000000000000000000000..cd989a6c23420e0e5216a668c76723b401851130 Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-junkyard/image1.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-junkyard/image2.png b/wham_demonstrator/Examples/0_wham-sample-junkyard/image2.png new file mode 100644 index 0000000000000000000000000000000000000000..484841d1c8855b4be6fe1d5bf5db9ba21597bc46 Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-junkyard/image2.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-junkyard/image3.png b/wham_demonstrator/Examples/0_wham-sample-junkyard/image3.png new file mode 100644 index 0000000000000000000000000000000000000000..a0ad76f249e944e6108adcc910bbcc6b1ce3e59e Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-junkyard/image3.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-junkyard/image4.png b/wham_demonstrator/Examples/0_wham-sample-junkyard/image4.png new file mode 100644 index 0000000000000000000000000000000000000000..a3c91ee2c34aef5ad8043c76c6e83cf5b2e2cfd6 Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-junkyard/image4.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-junkyard/image5.png b/wham_demonstrator/Examples/0_wham-sample-junkyard/image5.png new file mode 100644 index 0000000000000000000000000000000000000000..1749885ba81406401d589e1553792ae707af8883 Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-junkyard/image5.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-junkyard/image6.png b/wham_demonstrator/Examples/0_wham-sample-junkyard/image6.png new file mode 100644 index 0000000000000000000000000000000000000000..c3e8b97cadbef95a62072a60b4ef7b16a77c0e39 Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-junkyard/image6.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-junkyard/image7.png b/wham_demonstrator/Examples/0_wham-sample-junkyard/image7.png new file mode 100644 index 0000000000000000000000000000000000000000..e06bbb30790ebe08d70a6b31ba306a42a7f5fa78 Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-junkyard/image7.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-junkyard/image8.png b/wham_demonstrator/Examples/0_wham-sample-junkyard/image8.png new file mode 100644 index 0000000000000000000000000000000000000000..fd84f7d6fd416bf127d5fa9f3738efdea1d34819 Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-junkyard/image8.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-junkyard/image9.png b/wham_demonstrator/Examples/0_wham-sample-junkyard/image9.png new file mode 100644 index 0000000000000000000000000000000000000000..6a028ee2416921e15b1ac9be95fe8cf0ecd91efd Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-junkyard/image9.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-landslide/actions.json b/wham_demonstrator/Examples/0_wham-sample-landslide/actions.json new file mode 100644 index 0000000000000000000000000000000000000000..af8af43ebb262b4bb491f6a6a3fa783f8cfb2206 --- /dev/null +++ b/wham_demonstrator/Examples/0_wham-sample-landslide/actions.json @@ -0,0 +1 @@ +{"image0.png": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "image1.png": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "image2.png": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "image3.png": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "image4.png": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "image5.png": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8500000238418579, 0.625, 0.0, 0.0], "image6.png": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], "image7.png": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0], "image8.png": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.15000000596046448, -0.8500000238418579, 0.0, 0.0], "image9.png": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.375, -0.8500000238418579, 0.0, 0.0]} \ No newline at end of file diff --git a/wham_demonstrator/Examples/0_wham-sample-landslide/image0.png b/wham_demonstrator/Examples/0_wham-sample-landslide/image0.png new file mode 100644 index 0000000000000000000000000000000000000000..8b18036dc0cfc5bc01faa16b40fc4676cacce7d2 Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-landslide/image0.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-landslide/image1.png b/wham_demonstrator/Examples/0_wham-sample-landslide/image1.png new file mode 100644 index 0000000000000000000000000000000000000000..775153dc3b4afde0e6e5d570aff2313e754ccdcc Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-landslide/image1.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-landslide/image2.png b/wham_demonstrator/Examples/0_wham-sample-landslide/image2.png new file mode 100644 index 0000000000000000000000000000000000000000..aa4e2c6e547a81b4b2088cda4828be466e55860c Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-landslide/image2.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-landslide/image3.png b/wham_demonstrator/Examples/0_wham-sample-landslide/image3.png new file mode 100644 index 0000000000000000000000000000000000000000..24368b6aa4de3bc5ee81ba79d61f65b2db0b6a2c Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-landslide/image3.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-landslide/image4.png b/wham_demonstrator/Examples/0_wham-sample-landslide/image4.png new file mode 100644 index 0000000000000000000000000000000000000000..0086fc8f974203ef7f52aaebc1121091148a4b17 Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-landslide/image4.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-landslide/image5.png b/wham_demonstrator/Examples/0_wham-sample-landslide/image5.png new file mode 100644 index 0000000000000000000000000000000000000000..c035ceb69cc64db7ae1679baffcb1429b130df72 Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-landslide/image5.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-landslide/image6.png b/wham_demonstrator/Examples/0_wham-sample-landslide/image6.png new file mode 100644 index 0000000000000000000000000000000000000000..328fb5c572c53078e165748b56fab257e66befc9 Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-landslide/image6.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-landslide/image7.png b/wham_demonstrator/Examples/0_wham-sample-landslide/image7.png new file mode 100644 index 0000000000000000000000000000000000000000..21cf914173391a7113c1bc36d20884af0bbdfae3 Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-landslide/image7.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-landslide/image8.png b/wham_demonstrator/Examples/0_wham-sample-landslide/image8.png new file mode 100644 index 0000000000000000000000000000000000000000..4321e99629736f61fe67db01f0a15ca7e578be9a Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-landslide/image8.png differ diff --git a/wham_demonstrator/Examples/0_wham-sample-landslide/image9.png b/wham_demonstrator/Examples/0_wham-sample-landslide/image9.png new file mode 100644 index 0000000000000000000000000000000000000000..f83e1fdfc1ed1f2436bebc132e2e25e66e9d88d3 Binary files /dev/null and b/wham_demonstrator/Examples/0_wham-sample-landslide/image9.png differ diff --git a/wham_demonstrator/Examples/Layer_Image/0.png b/wham_demonstrator/Examples/Layer_Image/0.png new file mode 100644 index 0000000000000000000000000000000000000000..0b93c8ec92abdbbafd57386035a5cc71daa4113b Binary files /dev/null and b/wham_demonstrator/Examples/Layer_Image/0.png differ diff --git a/wham_demonstrator/Examples/Layer_Image/1.png b/wham_demonstrator/Examples/Layer_Image/1.png new file mode 100644 index 0000000000000000000000000000000000000000..386b8624649135825f27503c73a76001c6a56845 Binary files /dev/null and b/wham_demonstrator/Examples/Layer_Image/1.png differ diff --git a/wham_demonstrator/Examples/Layer_Image/2.png b/wham_demonstrator/Examples/Layer_Image/2.png new file mode 100644 index 0000000000000000000000000000000000000000..6b33c31777d8d880c3b60164675538e45309c31f Binary files /dev/null and b/wham_demonstrator/Examples/Layer_Image/2.png differ diff --git a/wham_demonstrator/Newtonsoft.Json.dll b/wham_demonstrator/Newtonsoft.Json.dll new file mode 100644 index 0000000000000000000000000000000000000000..d035c38b4edec5c10d4bc421a2dce19f5f998677 Binary files /dev/null and b/wham_demonstrator/Newtonsoft.Json.dll differ diff --git a/wham_demonstrator/Newtonsoft_LICENSE.txt b/wham_demonstrator/Newtonsoft_LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..05bc493e4d232ac37bf00726d4a0f2c335cfd4bc --- /dev/null +++ b/wham_demonstrator/Newtonsoft_LICENSE.txt @@ -0,0 +1,9 @@ +The MIT License (MIT) + +Copyright (c) 2007 James Newton-King + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/wham_demonstrator/README.md b/wham_demonstrator/README.md new file mode 100644 index 0000000000000000000000000000000000000000..044c17411d34efc67f6a730cd289c3eefd9a092c --- /dev/null +++ b/wham_demonstrator/README.md @@ -0,0 +1,79 @@ +> ### AI Generated content +> Models trained using game data may potentially behave in ways that are unfair, unreliable, or offensive, in turn causing harms. We emphasize that these types of harms are not mutually exclusive. A single model can exhibit more than one type of harm, potentially relating to multiple different groups of people. For example, the output of the model can be nonsensical or might look reasonable but is inaccurate with respect to external validation sources. + +# WHAM Demonstrator Instructions + +## Entering a Server IP Address +The model and API should be hosted on a server and port that is accessible to the app, and the machine running the app. When opening the app, enter the IP address with port into the text box at the top left of the window, in the form `http://127.0.0.1:5000`. You might need to click on Options for the Settings to be visible +![Figure 1](../assets/Demonstrator/Fig_01.png) + +## Opening Starting Frames +Creating sequences requires at least one start frame to generate gameplay sequence from. The WHAM Demonstrator contains some example starting frames. You can open the example frames using the “Open Example” button. +![Figure 2](../assets/Demonstrator/Fig_02.png) + +You can also start with either one or more images that you choose, or you can open a previously saved generated gameplay sequence timeline. +![Figure 3](../assets/Demonstrator/Fig_03.png) +To open one or more images to seed a generated gameplay sequence, select `File -> Open Image(s)…`, and select one or more images. The images will appear as the first images on the timeline at the bottom of the window. To open a previously saved timeline, select `File -> Open Timeline JSON…` and select the JSON file saved with the timeline. The seed images should now appear in the timeline at the bottom of the view. +![Figure 4](../assets/Demonstrator/Fig_04.png) + +## Prediction Parameters +When creating generated gameplay sequence, there are a number of parameters you can set: +![Figure 5](../assets/Demonstrator/Fig_05.png) +- **In**: The number of input frames into the next generated gameplay sequence. The maximum is 10. +- **Out**: The number of frames you would like to be returned per generated sequence. +- **Reps**: The number of branches for each generated gameplay sequence initiated. +- **Temp**: The temperature of the generated sequence. The higher this value, the more dynamic the generated gameplay sequence may be. + +## The Timeline User Interface +The timeline can get quite big. Here are a few controls to help manage the user interface when the timeline is large: +![Figure 6](../assets/Demonstrator/Fig_06.png) +- **Resize bar**: Click and drag this to change the vertical size of the timeline. +- **Frame selection**: If you want to focus in on a particular pathway, you can select the end frame, and all frames leading up to it will also be selected. +- **Selected frames only**: Toggle this to flatten the timeline to only the selected frames. +- **Zoom slider**: Use this to change the frame size. Note, smaller frames will be covered by the selection box. + +There is also a “snapshot” button that allows you to take a single picture of your entire timeline, even if it has scrolled out of view. +When you select a frame, you will see an outline around the frame in the timeline, and the frame will appear in the main area above. +![Figure 7](../assets/Demonstrator/Fig_07.png) + +**Note**: When making a frame the selected frame (and have the image appear above), you need to click on the image part of the frame thumbnail on the timeline, not the selection box. + +## Generating gameplay sequence +To create a generated gameplay sequence (or set of generated gamepla sequences), you must first click on the frame you would like to generate from. The WHAM Demonstrator allows creating new gameplay branches from any frame. So please ensure the frame you wish to generate from has been selected. If the parameters are correct and the server address has been entered, you can click “Predict” in the bar above the timeline. You should see the selected number of generated gameplay sequence branches, and new frames start to appear. +![Figure 8](../assets/Demonstrator/Fig_08.png) + +The frames may take a number of seconds to appear, depending on the hardware used to host the model. You can cancel one of the branches at any moment by clicking on the ‘X’ in the last frame square. When the frame generation is complete, you can select another frame anywhere in the timeline and select “Predict” again. +![Figure 9](../assets/Demonstrator/Fig_09.png) + +## Manipulating Images +New image layers can be added to any gameplay frame to introduce new content or characters to include in future generated frames. You could add a power-up or an NPC to a frame and have those included as the gameplay sequence evolves. + +**Note**: This feature is not fully supported yet in the models shared with the WHAM Demonstrator, so performance can be unpredictable. To help added elements really “stick” in the generated sequence, we recommend creating a sequence of 5 frames with your added element in before continuing to generate sequences. + +To add a new element, select the frame you wish to manipulate so it appears in the main frame area and click the `+` button in the layers panel. +![Figure 10](../assets/Demonstrator/Fig_10.png) + +You will then be prompted to select an image. We recommend a transparent PNG, like the example supplied within the Examples folder called "Layer_Image". The selected image will then appear both as a layer in the layers panel, and on the currently selected frame. +![Figure 11](../assets/Demonstrator/Fig_11.png) + +To move and resize the added element, click on the element, either in the layers panel, or on the image directly, and you can use your mouse wheel to resize, or drag it around to move. Here, the character has been placed on the right of the frame. +![Figure 12](../assets/Demonstrator/Fig_12.png) + +To easily create a sequence, click the “copy to next frame” button, and the layers will be added to the next frame for easier sequencing. Here are a further 4 frames showing the character movement as they enter from the right. +![Figure 13](../assets/Demonstrator/Fig_13.png) + +Predictions can now happen from the manipulated frame. + +**Note**: When saving a timeline, all frames are flattened, and so layer information will be lost. + +## Controller Input +The WHAM Demonstrator can also be controlled using an Xbox controller. +![Figure 14](../assets/Demonstrator/Fig_14.png) + +With a controller connected and a frame selected, you can hold down the buttons you wish to use on the controller for around a second and the desired input will be passed to the model for the next frame. When you hold buttons, a blue progress bar will appear below the controller, when this disappears, you must release all of the controller buttons. The WHAM Demonstrator does not yet support holding the buttons for multiple frame predictions. Also, in this mode, only one frame will be produced at a time. When you hold the buttons down on the controller, you will also notice that whenever a frame is selected, this controller image will show the action state of the controller too. + +## Saving Generated Gameplay Sequence +If you have created a gameplay sequence timeline that you want to save, or even continue later, you can save either a flat, selected sequence of frames, or the entire timeline. +![Figure 15](../assets/Demonstrator/Fig_15.png) + +Both of these options will ask for a folder and all of the required images and timeline information will be saved to that folder. \ No newline at end of file diff --git a/wham_demonstrator/SharpDX.XInput.dll b/wham_demonstrator/SharpDX.XInput.dll new file mode 100644 index 0000000000000000000000000000000000000000..8b1b014c54e593dd60bb2b4d3a80f12c436ddb01 Binary files /dev/null and b/wham_demonstrator/SharpDX.XInput.dll differ diff --git a/wham_demonstrator/SharpDX.dll b/wham_demonstrator/SharpDX.dll new file mode 100644 index 0000000000000000000000000000000000000000..f3b7388e604c35fab19b1776b247b55a4b1a90b7 Binary files /dev/null and b/wham_demonstrator/SharpDX.dll differ diff --git a/wham_demonstrator/SharpDX_LICENSE.txt b/wham_demonstrator/SharpDX_LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..c8adfecad7eb846e1be4bb64c48cfe825aebb4c1 --- /dev/null +++ b/wham_demonstrator/SharpDX_LICENSE.txt @@ -0,0 +1,19 @@ +Copyright (c) 2010-2014 SharpDX - Alexandre Mutel + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/wham_demonstrator/WHAMDemonstrator.dll b/wham_demonstrator/WHAMDemonstrator.dll new file mode 100644 index 0000000000000000000000000000000000000000..30eccd1978f731f537987e8966c404d723517183 Binary files /dev/null and b/wham_demonstrator/WHAMDemonstrator.dll differ diff --git a/wham_demonstrator/WHAMDemonstrator.exe b/wham_demonstrator/WHAMDemonstrator.exe new file mode 100644 index 0000000000000000000000000000000000000000..81826d4a03999120c20164149b791a2303852995 Binary files /dev/null and b/wham_demonstrator/WHAMDemonstrator.exe differ diff --git a/wham_demonstrator/WHAMDemonstrator.runtimeconfig.json b/wham_demonstrator/WHAMDemonstrator.runtimeconfig.json new file mode 100644 index 0000000000000000000000000000000000000000..0b3b37138aeac3ce9a3f5b1f6d7acecf5b5b5cf1 --- /dev/null +++ b/wham_demonstrator/WHAMDemonstrator.runtimeconfig.json @@ -0,0 +1,19 @@ +{ + "runtimeOptions": { + "tfm": "net7.0", + "frameworks": [ + { + "name": "Microsoft.NETCore.App", + "version": "7.0.0" + }, + { + "name": "Microsoft.WindowsDesktop.App", + "version": "7.0.0" + } + ], + "configProperties": { + "System.Reflection.Metadata.MetadataUpdater.IsSupported": false, + "System.Runtime.Serialization.EnableUnsafeBinaryFormatterSerialization": true + } + } +} \ No newline at end of file