testing files upload (#7)
Commit f489a598b0d6a46d9a99c819210b220269d9b29b
- .gitignore +32 -0
- LICENSE +399 -0
- MODELCARD.md +128 -0
- README.md +34 -120
- config.yaml +16 -0
- generate_reconstructions.ipynb +0 -0
- huggingface_mae.py +293 -0
- loss.py +59 -0
- mae_modules.py +273 -0
- mae_utils.py +70 -0
- masking.py +51 -0
- normalizer.py +7 -0
- pyproject.toml +34 -0
- sample/AA41_s1_1.jp2 +0 -0
- sample/AA41_s1_2.jp2 +0 -0
- sample/AA41_s1_3.jp2 +0 -0
- sample/AA41_s1_4.jp2 +0 -0
- sample/AA41_s1_5.jp2 +0 -0
- sample/AA41_s1_6.jp2 +0 -0
- test_huggingface_mae.py +32 -0
- vit.py +309 -0
- vit_encoder.py +61 -0
    	
        .gitignore
    ADDED
    
@@ -0,0 +1,32 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# model artifacts
*.pickle
*.ckpt
*.safetensors
    	
        LICENSE
    ADDED
    
@@ -0,0 +1,399 @@
Attribution-NonCommercial 4.0 International

=======================================================================

Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.

     Considerations for licensors: Our public licenses are
     intended for use by those authorized to give the public
     permission to use material in ways otherwise restricted by
     copyright and certain other rights. Our licenses are
     irrevocable. Licensors should read and understand the terms
     and conditions of the license they choose before applying it.
     Licensors should also secure all rights necessary before
     applying our licenses so that the public can reuse the
     material as expected. Licensors should clearly mark any
     material not subject to the license. This includes other CC-
     licensed material, or material used under an exception or
     limitation to copyright. More considerations for licensors:
     wiki.creativecommons.org/Considerations_for_licensors

     Considerations for the public: By using one of our public
     licenses, a licensor grants the public permission to use the
     licensed material under specified terms and conditions. If
     the licensor's permission is not necessary for any reason--for
     example, because of any applicable exception or limitation to
     copyright--then that use is not regulated by the license. Our
     licenses grant only permissions under copyright and certain
     other rights that a licensor has authority to grant. Use of
     the licensed material may still be restricted for other
     reasons, including because others have copyright or other
     rights in the material. A licensor may make special requests,
     such as asking that all changes be marked or described.
     Although not required by our licenses, you are encouraged to
     respect those requests where reasonable. More_considerations
     for the public:
     wiki.creativecommons.org/Considerations_for_licensees

=======================================================================

Creative Commons Attribution-NonCommercial 4.0 International Public
License

By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial 4.0 International Public License ("Public
License"). To the extent this Public License may be interpreted as a
contract, You are granted the Licensed Rights in consideration of Your
acceptance of these terms and conditions, and the Licensor grants You
such rights in consideration of benefits the Licensor receives from
making the Licensed Material available under these terms and
conditions.

Section 1 -- Definitions.

  a. Adapted Material means material subject to Copyright and Similar
     Rights that is derived from or based upon the Licensed Material
     and in which the Licensed Material is translated, altered,
     arranged, transformed, or otherwise modified in a manner requiring
     permission under the Copyright and Similar Rights held by the
     Licensor. For purposes of this Public License, where the Licensed
     Material is a musical work, performance, or sound recording,
     Adapted Material is always produced where the Licensed Material is
     synched in timed relation with a moving image.

  b. Adapter's License means the license You apply to Your Copyright
     and Similar Rights in Your contributions to Adapted Material in
     accordance with the terms and conditions of this Public License.

  c. Copyright and Similar Rights means copyright and/or similar rights
     closely related to copyright including, without limitation,
     performance, broadcast, sound recording, and Sui Generis Database
     Rights, without regard to how the rights are labeled or
     categorized. For purposes of this Public License, the rights
     specified in Section 2(b)(1)-(2) are not Copyright and Similar
     Rights.
  d. Effective Technological Measures means those measures that, in the
     absence of proper authority, may not be circumvented under laws
     fulfilling obligations under Article 11 of the WIPO Copyright
     Treaty adopted on December 20, 1996, and/or similar international
     agreements.

  e. Exceptions and Limitations means fair use, fair dealing, and/or
     any other exception or limitation to Copyright and Similar Rights
     that applies to Your use of the Licensed Material.

  f. Licensed Material means the artistic or literary work, database,
     or other material to which the Licensor applied this Public
     License.

  g. Licensed Rights means the rights granted to You subject to the
     terms and conditions of this Public License, which are limited to
     all Copyright and Similar Rights that apply to Your use of the
     Licensed Material and that the Licensor has authority to license.

  h. Licensor means the individual(s) or entity(ies) granting rights
     under this Public License.

  i. NonCommercial means not primarily intended for or directed towards
     commercial advantage or monetary compensation. For purposes of
     this Public License, the exchange of the Licensed Material for
     other material subject to Copyright and Similar Rights by digital
     file-sharing or similar means is NonCommercial provided there is
     no payment of monetary compensation in connection with the
     exchange.

  j. Share means to provide material to the public by any means or
     process that requires permission under the Licensed Rights, such
     as reproduction, public display, public performance, distribution,
     dissemination, communication, or importation, and to make material
     available to the public including in ways that members of the
     public may access the material from a place and at a time
     individually chosen by them.

  k. Sui Generis Database Rights means rights other than copyright
     resulting from Directive 96/9/EC of the European Parliament and of
     the Council of 11 March 1996 on the legal protection of databases,
     as amended and/or succeeded, as well as other essentially
     equivalent rights anywhere in the world.

  l. You means the individual or entity exercising the Licensed Rights
     under this Public License. Your has a corresponding meaning.

Section 2 -- Scope.

  a. License grant.

       1. Subject to the terms and conditions of this Public License,
          the Licensor hereby grants You a worldwide, royalty-free,
          non-sublicensable, non-exclusive, irrevocable license to
          exercise the Licensed Rights in the Licensed Material to:

            a. reproduce and Share the Licensed Material, in whole or
               in part, for NonCommercial purposes only; and

            b. produce, reproduce, and Share Adapted Material for
               NonCommercial purposes only.

       2. Exceptions and Limitations. For the avoidance of doubt, where
          Exceptions and Limitations apply to Your use, this Public
          License does not apply, and You do not need to comply with
          its terms and conditions.

       3. Term. The term of this Public License is specified in Section
          6(a).

       4. Media and formats; technical modifications allowed. The
          Licensor authorizes You to exercise the Licensed Rights in
          all media and formats whether now known or hereafter created,
          and to make technical modifications necessary to do so. The
          Licensor waives and/or agrees not to assert any right or
          authority to forbid You from making technical modifications
          necessary to exercise the Licensed Rights, including
          technical modifications necessary to circumvent Effective
          Technological Measures. For purposes of this Public License,
          simply making modifications authorized by this Section 2(a)
          (4) never produces Adapted Material.

       5. Downstream recipients.

            a. Offer from the Licensor -- Licensed Material. Every
               recipient of the Licensed Material automatically
               receives an offer from the Licensor to exercise the
               Licensed Rights under the terms and conditions of this
               Public License.

            b. No downstream restrictions. You may not offer or impose
               any additional or different terms or conditions on, or
               apply any Effective Technological Measures to, the
               Licensed Material if doing so restricts exercise of the
               Licensed Rights by any recipient of the Licensed
               Material.

       6. No endorsement. Nothing in this Public License constitutes or
          may be construed as permission to assert or imply that You
          are, or that Your use of the Licensed Material is, connected
          with, or sponsored, endorsed, or granted official status by,
          the Licensor or others designated to receive attribution as
          provided in Section 3(a)(1)(A)(i).

  b. Other rights.

       1. Moral rights, such as the right of integrity, are not
          licensed under this Public License, nor are publicity,
          privacy, and/or other similar personality rights; however, to
          the extent possible, the Licensor waives and/or agrees not to
          assert any such rights held by the Licensor to the limited
          extent necessary to allow You to exercise the Licensed
          Rights, but not otherwise.

       2. Patent and trademark rights are not licensed under this
          Public License.

       3. To the extent possible, the Licensor waives any right to
          collect royalties from You for the exercise of the Licensed
          Rights, whether directly or through a collecting society
          under any voluntary or waivable statutory or compulsory
          licensing scheme. In all other cases the Licensor expressly
          reserves any right to collect such royalties, including when
          the Licensed Material is used other than for NonCommercial
          purposes.

Section 3 -- License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the
following conditions.

  a. Attribution.

       1. If You Share the Licensed Material (including in modified
          form), You must:

            a. retain the following if it is supplied by the Licensor
               with the Licensed Material:

                 i. identification of the creator(s) of the Licensed
                    Material and any others designated to receive
                    attribution, in any reasonable manner requested by
                    the Licensor (including by pseudonym if
                    designated);

                ii. a copyright notice;

               iii. a notice that refers to this Public License;

                iv. a notice that refers to the disclaimer of
                    warranties;

                 v. a URI or hyperlink to the Licensed Material to the
                    extent reasonably practicable;

            b. indicate if You modified the Licensed Material and
               retain an indication of any previous modifications; and

            c. indicate the Licensed Material is licensed under this
               Public License, and include the text of, or the URI or
               hyperlink to, this Public License.

       2. You may satisfy the conditions in Section 3(a)(1) in any
          reasonable manner based on the medium, means, and context in
          which You Share the Licensed Material. For example, it may be
          reasonable to satisfy the conditions by providing a URI or
          hyperlink to a resource that includes the required
          information.

       3. If requested by the Licensor, You must remove any of the
          information required by Section 3(a)(1)(A) to the extent
          reasonably practicable.

       4. If You Share Adapted Material You produce, the Adapter's
          License You apply must not prevent recipients of the Adapted
          Material from complying with this Public License.

Section 4 -- Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:

  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
     to extract, reuse, reproduce, and Share all or a substantial
     portion of the contents of the database for NonCommercial purposes
     only;

  b. if You include all or a substantial portion of the database
     contents in a database in which You have Sui Generis Database
     Rights, then the database in which You have Sui Generis Database
     Rights (but not its individual contents) is Adapted Material; and

  c. You must comply with the conditions in Section 3(a) if You Share
     all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.

Section 5 -- Disclaimer of Warranties and Limitation of Liability.

  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.

  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.

  c. The disclaimer of warranties and limitation of liability provided
     above shall be interpreted in a manner that, to the extent
     possible, most closely approximates an absolute disclaimer and
     waiver of all liability.

Section 6 -- Term and Termination.

  a. This Public License applies for the term of the Copyright and
     Similar Rights licensed here. However, if You fail to comply with
     this Public License, then Your rights under this Public License
     terminate automatically.

  b. Where Your right to use the Licensed Material has terminated under
     Section 6(a), it reinstates:

       1. automatically as of the date the violation is cured, provided
          it is cured within 30 days of Your discovery of the
          violation; or

       2. upon express reinstatement by the Licensor.

     For the avoidance of doubt, this Section 6(b) does not affect any
     right the Licensor may have to seek remedies for Your violations
     of this Public License.

  c. For the avoidance of doubt, the Licensor may also offer the
     Licensed Material under separate terms or conditions or stop
     distributing the Licensed Material at any time; however, doing so
     will not terminate this Public License.

  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
     License.

Section 7 -- Other Terms and Conditions.

  a. The Licensor shall not be bound by any additional or different
     terms or conditions communicated by You unless expressly agreed.

  b. Any arrangements, understandings, or agreements regarding the
     Licensed Material not stated herein are separate from and
     independent of the terms and conditions of this Public License.

Section 8 -- Interpretation.

  a. For the avoidance of doubt, this Public License does not, and
     shall not be interpreted to, reduce, limit, restrict, or impose
     conditions on any use of the Licensed Material that could lawfully
     be made without permission under this Public License.

  b. To the extent possible, if any provision of this Public License is
     deemed unenforceable, it shall be automatically reformed to the
     minimum extent necessary to make it enforceable. If the provision
     cannot be reformed, it shall be severed from this Public License
     without affecting the enforceability of the remaining terms and
     conditions.

  c. No term or condition of this Public License will be waived and no
     failure to comply consented to unless expressly agreed to by the
     Licensor.

  d. Nothing in this Public License constitutes or may be interpreted
     as a limitation upon, or waiver of, any privileges and immunities
     that apply to the Licensor or You, including from the legal
     processes of any jurisdiction or authority.

=======================================================================

Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the "Licensor." The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.

Creative Commons may be contacted at creativecommons.org.
    	
        MODELCARD.md
    ADDED
    
@@ -0,0 +1,128 @@
---
library_name: transformers
tags: []
---

# Model Card for Phenom CA-MAE-S/16

A channel-agnostic image encoding model designed for microscopy image featurization.
The model uses a vision transformer backbone with channelwise cross-attention over patch tokens to create contextualized representations separately for each channel.


## Model Details

### Model Description

This model is a [channel-agnostic masked autoencoder](https://openaccess.thecvf.com/content/CVPR2024/html/Kraus_Masked_Autoencoders_for_Microscopy_are_Scalable_Learners_of_Cellular_Biology_CVPR_2024_paper.html) trained to reconstruct microscopy images across three datasets:
1. RxRx3
2. JUMP-CP overexpression
3. JUMP-CP gene-knockouts

- **Developed, funded, and shared by:** Recursion
- **Model type:** Vision transformer CA-MAE
- **Image modality:** Optimized for microscopy images from the CellPainting assay
- **License:** CC BY-NC 4.0 (see the LICENSE file added in this commit)


### Model Sources

- **Repository:** [https://github.com/recursionpharma/maes_microscopy](https://github.com/recursionpharma/maes_microscopy)
- **Paper:** [Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology](https://openaccess.thecvf.com/content/CVPR2024/html/Kraus_Masked_Autoencoders_for_Microscopy_are_Scalable_Learners_of_Cellular_Biology_CVPR_2024_paper.html)


## Uses

NOTE: the model's embeddings tend to become useful only after standard batch-correction post-processing. At a *minimum*, **we recommend** applying the standard `PCA-CenterScale` pattern, or better yet Typical Variation Normalization, after running inference over your images (a sketch follows the steps below):

1. Fit a PCA kernel on all the *control images* (or all images, if there are no controls) from across all experimental batches (e.g. the plates of wells from your assay),
2. Transform all the embeddings with that PCA kernel,
3. For each experimental batch, fit a separate StandardScaler on the transformed embeddings of the controls from step 2, then transform the rest of the embeddings from that batch with that StandardScaler.

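As a rough, non-authoritative illustration of the recipe above, here is a minimal sketch using scikit-learn; the names `embeddings`, `batch_ids`, `is_control`, and `pca_center_scale` are hypothetical placeholders for your own inference outputs and metadata, not part of this repository:

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler


def pca_center_scale(embeddings: np.ndarray, batch_ids: np.ndarray, is_control: np.ndarray) -> np.ndarray:
    """Hypothetical helper: PCA fit on pooled controls, then per-batch center/scale fit on controls."""
    # Step 1: fit the PCA kernel on control embeddings pooled across all experimental batches.
    pca = PCA().fit(embeddings[is_control])
    # Step 2: transform every embedding with that kernel.
    transformed = pca.transform(embeddings)
    # Step 3: for each batch, fit a StandardScaler on that batch's transformed controls
    # and apply it to all embeddings from the same batch.
    corrected = np.empty_like(transformed)
    for batch in np.unique(batch_ids):
        in_batch = batch_ids == batch
        scaler = StandardScaler().fit(transformed[in_batch & is_control])
        corrected[in_batch] = scaler.transform(transformed[in_batch])
    return corrected
```
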
### Direct Use

- Create biologically useful embeddings of microscopy images
- Create contextualized embeddings of each channel of a microscopy image (set `return_channelwise_embeddings=True`)
- Leverage the full MAE encoder + decoder to predict new channels / stains for images without all 6 CellPainting channels

### Downstream Use

- A determined ML expert could fine-tune the encoder for downstream tasks such as classification (a lightweight starting point is sketched below)

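The following is only a sketch of that lightweight starting point: a linear probe trained on frozen embeddings via the public `predict` API shown in this card. The `images`, `labels`, and class-count values are hypothetical, and full fine-tuning of the encoder itself would instead build on the modules in this repo (e.g. `mae_modules.py`):

```python
import torch
from torch import nn

from huggingface_mae import MAEModel

model = MAEModel.from_pretrained(".")  # frozen feature extractor
model.eval()

probe = nn.Linear(384, 10)  # 10 = hypothetical number of classes
optimizer = torch.optim.AdamW(probe.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()


def probe_step(images: torch.Tensor, labels: torch.Tensor) -> float:
    """One optimization step: frozen MAE embeddings feed a trainable linear classifier."""
    with torch.no_grad():
        features = model.predict(images)  # pooled embeddings, shape (B, 384)
    logits = probe(features.float())
    loss = loss_fn(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```
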
### Out-of-Scope Use

- Unlikely to be especially performant on brightfield microscopy images
- Out-of-domain medical images, such as H&E (though it might be a decent baseline)

## Bias, Risks, and Limitations

- The primary limitation is that the embeddings tend to be more useful at scale. For example, if you only have one plate of microscopy images, the embeddings might underperform compared to a bespoke supervised model.

## How to Get Started with the Model

You should be able to run the tests below successfully; they demonstrate how to use the model at inference time.

```python
import pytest
import torch

from huggingface_mae import MAEModel

huggingface_phenombeta_model_dir = "."
# huggingface_modelpath = "recursionpharma/test-pb-model"


@pytest.fixture
def huggingface_model():
    # Make sure you have the model/config downloaded from https://huggingface.co/recursionpharma/test-pb-model to this directory:
    #   huggingface-cli download recursionpharma/test-pb-model --local-dir=.
    huggingface_model = MAEModel.from_pretrained(huggingface_phenombeta_model_dir)
    huggingface_model.eval()
    return huggingface_model


@pytest.mark.parametrize("C", [1, 4, 6, 11])
@pytest.mark.parametrize("return_channelwise_embeddings", [True, False])
def test_model_predict(huggingface_model, C, return_channelwise_embeddings):
    example_input_array = torch.randint(
        low=0,
        high=255,
        size=(2, C, 256, 256),
        dtype=torch.uint8,
        device=huggingface_model.device,
    )
    huggingface_model.return_channelwise_embeddings = return_channelwise_embeddings
    embeddings = huggingface_model.predict(example_input_array)
    expected_output_dim = 384 * C if return_channelwise_embeddings else 384
    assert embeddings.shape == (2, expected_output_dim)
```
            ```
         | 
| 98 | 
            +
             | 
| 99 | 
            +
             | 
| 100 | 
            +
            ## Training, evaluation and testing details
         | 
| 101 | 
            +
             | 
| 102 | 
            +
            See paper linked above for details on model training and evaluation. Primary hyperparameters are included in the repo linked above.
         | 
| 103 | 
            +
             | 
| 104 | 
            +
             | 
| 105 | 
            +
            ## Environmental Impact
         | 
| 106 | 
            +
             | 
| 107 | 
            +
            - **Hardware Type:** Nvidia H100 Hopper nodes
         | 
| 108 | 
            +
            - **Hours used:** 400
         | 
| 109 | 
            +
            - **Cloud Provider:** private cloud
         | 
| 110 | 
            +
            - **Carbon Emitted:** 138.24 kg co2 (roughly the equivalent of one car driving from Toronto to Montreal)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
            **BibTeX:**
         | 
| 113 | 
            +
             | 
| 114 | 
            +
            ```TeX
         | 
| 115 | 
            +
            @inproceedings{kraus2024masked,
         | 
| 116 | 
            +
              title={Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology},
         | 
| 117 | 
            +
              author={Kraus, Oren and Kenyon-Dean, Kian and Saberian, Saber and Fallah, Maryam and McLean, Peter and Leung, Jess and Sharma, Vasudev and Khan, Ayla and Balakrishnan, Jia and Celik, Safiye and others},
         | 
| 118 | 
            +
              booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
         | 
| 119 | 
            +
              pages={11757--11768},
         | 
| 120 | 
            +
              year={2024}
         | 
| 121 | 
            +
            }
         | 
| 122 | 
            +
            ```
         | 
| 123 | 
            +
             | 
| 124 | 
            +
            ## Model Card Contact
         | 
| 125 | 
            +
             | 
| 126 | 
            +
            - Kian Kenyon-Dean: [email protected]
         | 
| 127 | 
            +
            - Oren Kraus: [email protected]
         | 
| 128 | 
            +
            - Or, email: [email protected]
         | 
    	
        README.md
    CHANGED
    
@@ -1,128 +1,42 @@
         | 
| 71 | 
            -
            # huggingface_modelpath = "recursionpharma/test-pb-model"
         | 
| 72 | 
            -
             | 
| 73 | 
            -
             | 
| 74 | 
            -
            @pytest.fixture
         | 
| 75 | 
            -
            def huggingface_model():
         | 
| 76 | 
            -
                # Make sure you have the model/config downloaded from https://huggingface.co/recursionpharma/test-pb-model to this directory
         | 
| 77 | 
            -
                # huggingface-cli download recursionpharma/test-pb-model --local-dir=models/phenom_beta_huggingface
         | 
| 78 | 
            -
                huggingface_model = MAEModel.from_pretrained(huggingface_phenombeta_model_dir)
         | 
| 79 | 
            -
                huggingface_model.eval()
         | 
| 80 | 
            -
                return huggingface_model
         | 
| 81 | 
            -
             | 
| 82 | 
            -
             | 
| 83 | 
            -
            @pytest.mark.parametrize("C", [1, 4, 6, 11])
         | 
| 84 | 
            -
            @pytest.mark.parametrize("return_channelwise_embeddings", [True, False])
         | 
| 85 | 
            -
            def test_model_predict(huggingface_model, C, return_channelwise_embeddings):
         | 
| 86 | 
            -
                example_input_array = torch.randint(
         | 
| 87 | 
            -
                    low=0,
         | 
| 88 | 
            -
                    high=255,
         | 
| 89 | 
            -
                    size=(2, C, 256, 256),
         | 
| 90 | 
            -
                    dtype=torch.uint8,
         | 
| 91 | 
            -
                    device=huggingface_model.device,
         | 
| 92 | 
            -
                )
         | 
| 93 | 
            -
                huggingface_model.return_channelwise_embeddings = return_channelwise_embeddings
         | 
| 94 | 
            -
                embeddings = huggingface_model.predict(example_input_array)
         | 
| 95 | 
            -
                expected_output_dim = 384 * C if return_channelwise_embeddings else 384
         | 
| 96 | 
            -
                assert embeddings.shape == (2, expected_output_dim)
         | 
| 97 | 
             
            ```
         | 
| 98 | 
            -
             | 
| 99 | 
            -
             | 
| 100 | 
            -
             | 
| 101 | 
            -
             | 
| 102 | 
            -
             | 
| 103 | 
            -
             | 
| 104 | 
            -
             | 
| 105 | 
            -
             | 
| 106 | 
            -
             | 
| 107 | 
            -
             | 
| 108 | 
            -
             | 
| 109 | 
            -
             | 
| 110 | 
            -
             | 
| 111 | 
            -
             | 
| 112 | 
            -
             | 
| 113 | 
            -
             | 
| 114 | 
            -
             | 
| 115 | 
            -
             | 
| 116 | 
            -
              title={Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology},
         | 
| 117 | 
            -
              author={Kraus, Oren and Kenyon-Dean, Kian and Saberian, Saber and Fallah, Maryam and McLean, Peter and Leung, Jess and Sharma, Vasudev and Khan, Ayla and Balakrishnan, Jia and Celik, Safiye and others},
         | 
| 118 | 
            -
              booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
         | 
| 119 | 
            -
              pages={11757--11768},
         | 
| 120 | 
            -
              year={2024}
         | 
| 121 | 
            -
            }
         | 
| 122 | 
             
            ```
         | 
| 123 |  | 
| 124 | 
            -
            ##  | 
|  | |
| 125 |  | 
| 126 | 
            -
            -  | 
| 127 | 
            -
            -  | 
| 128 | 
            -
            -  | 
|  | |
| 1 | 
            +
            # Masked Autoencoders are Scalable Learners of Cellular Morphology
         | 
| 2 | 
            +
            Official repo for Recursion's two recently accepted papers:
         | 
| 3 | 
            +
            - Spotlight full-length paper at [CVPR 2024](https://cvpr.thecvf.com/Conferences/2024/AcceptedPapers) -- Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology
         | 
| 4 | 
            +
              - Paper: https://arxiv.org/abs/2404.10242
         | 
| 5 | 
            +
              - CVPR poster page with video: https://cvpr.thecvf.com/virtual/2024/poster/31565
         | 
| 6 | 
            +
            - Spotlight workshop paper at [NeurIPS 2023 Generative AI & Biology workshop](https://openreview.net/group?id=NeurIPS.cc/2023/Workshop/GenBio)
         | 
| 7 | 
            +
              - Paper: https://arxiv.org/abs/2309.16064
         | 
| 8 |  | 
| 9 | 
            +
            
         | 
| 10 |  | 
|  | |
|  | |
| 11 |  | 
| 12 | 
            +
            ## Provided code
         | 
| 13 | 
            +
See the repo for the ingredients required to define our MAEs. Users seeking to re-implement training will need to stitch together the Encoder and Decoder modules according to their use case; a minimal sketch of that stitching is shown below.
         | 
| 14 |  | 
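As a rough sketch of that stitching (it mirrors what `huggingface_mae.py` in this repo does; the hyperparameters are illustrative and this is not a drop-in training module):

```python
import torch.nn as nn

from mae_modules import CAMAEDecoder, MAEEncoder
from vit import sincos_positional_encoding_vit, vit_small_patch16_256

# Channel-agnostic ViT-S/16 encoder over up to 11 input channels, with sin-cos position embeddings
encoder = MAEEncoder(
    vit_backbone=sincos_positional_encoding_vit(
        vit_backbone=vit_small_patch16_256(global_pool="avg")
    ),
    max_in_chans=11,
    channel_agnostic=True,
)

# Channel-agnostic decoder: 6 modalities x 256 tokens per modality for 6-channel 256x256 inputs
decoder = CAMAEDecoder(
    num_modalities=6,
    tokens_per_modality=256,
    embed_dim=512,
    depth=8,
    num_heads=16,
    mlp_ratio=4,
    qkv_bias=True,
    norm_layer=nn.LayerNorm,
)

# Linear projections joining the two halves: encoder dim -> decoder dim -> pixels per patch token
encoder_decoder_proj = nn.Linear(encoder.embed_dim, decoder.embed_dim, bias=True)
decoder_pred = nn.Linear(decoder.embed_dim, 16 * 16, bias=True)

# A masked forward pass then chains (see MAEModel.forward / decode_to_reconstruction):
#   latent, mask, ind_restore = encoder.forward_masked(imgs, mask_ratio=0.75)
#   tokens = decoder.forward_masked(encoder_decoder_proj(latent), ind_restore)
#   reconstruction = decoder_pred(tokens)[:, 1:, :]  # drop the class token
# Note that huggingface_mae.py also overwrites decoder.pos_embeddings with 2D sin-cos embeddings first.
```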
| 15 | 
            +
Furthermore, the baseline Vision Transformer backbone used in this work can be built with the following timm code snippet:
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 16 | 
             
            ```
         | 
| 17 | 
            +
            import timm.models.vision_transformer as vit
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            def vit_base_patch16_256(**kwargs):
         | 
| 20 | 
            +
                default_kwargs = dict(
         | 
| 21 | 
            +
                    img_size=256,
         | 
| 22 | 
            +
                    in_chans=6,
         | 
| 23 | 
            +
                    num_classes=0,
         | 
| 24 | 
            +
                    fc_norm=None,
         | 
| 25 | 
            +
                    class_token=True,
         | 
| 26 | 
            +
                    drop_path_rate=0.1,
         | 
| 27 | 
            +
                    init_values=0.0001,
         | 
| 28 | 
            +
                    block_fn=vit.ParallelScalingBlock,
         | 
| 29 | 
            +
                    qkv_bias=False,
         | 
| 30 | 
            +
                    qk_norm=True,
         | 
| 31 | 
            +
                )
         | 
| 32 | 
            +
                for k, v in kwargs.items():
         | 
| 33 | 
            +
                    default_kwargs[k] = v
         | 
| 34 | 
            +
                return vit.vit_base_patch16_224(**default_kwargs)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 35 | 
             
            ```
         | 
| 36 |  | 
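As a quick, hypothetical smoke test of the snippet above (it assumes a recent timm release that ships `ParallelScalingBlock`; with `num_classes=0` the model returns pooled features rather than logits):

```python
import torch

backbone = vit_base_patch16_256()    # 6-channel, 256x256 ViT-B/16 as defined above
dummy = torch.randn(2, 6, 256, 256)  # a batch of two 6-channel images
features = backbone(dummy)
print(features.shape)                # torch.Size([2, 768]) -- pooled ViT-B features
```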
| 37 | 
            +
            ## Provided models
         | 
| 38 | 
            +
            A publicly available model for research can be found via Nvidia's BioNemo platform, which handles inference and auto-scaling: https://www.rxrx.ai/phenom
         | 
| 39 |  | 
| 40 | 
            +
We have partnered with Nvidia to host Phenom-Beta, a smaller and more flexible publicly available version of the MAE phenomics foundation model. Interested parties can access it directly through the Nvidia BioNemo API:
         | 
| 41 | 
            +
            - https://blogs.nvidia.com/blog/drug-discovery-bionemo-generative-ai/
         | 
| 42 | 
            +
            - https://www.youtube.com/watch?v=Gch6bX1toB0
         | 
    	
        config.yaml
    ADDED
    
    | @@ -0,0 +1,16 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # © Recursion Pharmaceuticals 2024
         | 
| 2 | 
            +
            loss:
         | 
| 3 | 
            +
              _target_: torch.nn.MSELoss  # combine with fourier loss weighted at 0.01 mixing factor for best results
         | 
| 4 | 
            +
              reduction: none
         | 
| 5 | 
            +
            optimizer:
         | 
| 6 | 
            +
              _target_: timm.optim.lion.Lion
         | 
| 7 | 
            +
              _partial_: true
         | 
| 8 | 
            +
  lr: &lr 1e-4   # 1e-4 for <= ViT-B, and 3e-5 for ViT-L
         | 
| 9 | 
            +
              weight_decay: 0.05
         | 
| 10 | 
            +
              betas: [0.9, 0.95]
         | 
| 11 | 
            +
            lr_scheduler:
         | 
| 12 | 
            +
              _target_: torch.optim.lr_scheduler.OneCycleLR
         | 
| 13 | 
            +
              _partial_: true
         | 
| 14 | 
            +
  max_lr: *lr  # reuses the optimizer learning rate anchored above
         | 
| 15 | 
            +
              pct_start: 0.1
         | 
| 16 | 
            +
              anneal_strategy: cos
         | 
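The config above follows Hydra's `_target_`/`_partial_` convention, with a YAML anchor so the scheduler's `max_lr` reuses the optimizer learning rate. A minimal sketch of how such a config could be materialized (assuming Hydra >= 1.1 and timm are installed; `model` and `num_training_steps` are placeholders defined elsewhere):

```python
# Sketch only; not a training script from this repo.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")

loss_fn = instantiate(cfg.loss)                  # torch.nn.MSELoss(reduction="none")
make_optimizer = instantiate(cfg.optimizer)      # partial of timm.optim.lion.Lion (lr=1e-4)
optimizer = make_optimizer(model.parameters())
make_scheduler = instantiate(cfg.lr_scheduler)   # partial of OneCycleLR with max_lr=1e-4
lr_scheduler = make_scheduler(optimizer, total_steps=num_training_steps)
```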
    	
        generate_reconstructions.ipynb
    ADDED
    
    | The diff for this file is too large to render. 
		See raw diff | 
|  | 
    	
        huggingface_mae.py
    ADDED
    
    | @@ -0,0 +1,293 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import Dict, Tuple, Union
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.nn as nn
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from transformers import PretrainedConfig, PreTrainedModel
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from loss import FourierLoss
         | 
| 9 | 
            +
            from normalizer import Normalizer
         | 
| 10 | 
            +
            from mae_modules import CAMAEDecoder, MAEDecoder, MAEEncoder
         | 
| 11 | 
            +
            from mae_utils import flatten_images
         | 
| 12 | 
            +
            from vit import (
         | 
| 13 | 
            +
                generate_2d_sincos_pos_embeddings,
         | 
| 14 | 
            +
                sincos_positional_encoding_vit,
         | 
| 15 | 
            +
                vit_small_patch16_256,
         | 
| 16 | 
            +
            )
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            TensorDict = Dict[str, torch.Tensor]
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            class MAEConfig(PretrainedConfig):
         | 
| 22 | 
            +
                model_type = "MAE"
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                def __init__(
         | 
| 25 | 
            +
                    self,
         | 
| 26 | 
            +
                    mask_ratio=0.75,
         | 
| 27 | 
            +
                    encoder=None,
         | 
| 28 | 
            +
                    decoder=None,
         | 
| 29 | 
            +
                    loss=None,
         | 
| 30 | 
            +
                    optimizer=None,
         | 
| 31 | 
            +
                    input_norm=None,
         | 
| 32 | 
            +
                    fourier_loss=None,
         | 
| 33 | 
            +
                    fourier_loss_weight=0.0,
         | 
| 34 | 
            +
                    lr_scheduler=None,
         | 
| 35 | 
            +
                    use_MAE_weight_init=False,
         | 
| 36 | 
            +
                    crop_size=-1,
         | 
| 37 | 
            +
                    mask_fourier_loss=True,
         | 
| 38 | 
            +
                    return_channelwise_embeddings=False,
         | 
| 39 | 
            +
                    **kwargs,
         | 
| 40 | 
            +
                ):
         | 
| 41 | 
            +
                    super().__init__(**kwargs)
         | 
| 42 | 
            +
                    self.mask_ratio = mask_ratio
         | 
| 43 | 
            +
                    self.encoder = encoder
         | 
| 44 | 
            +
                    self.decoder = decoder
         | 
| 45 | 
            +
                    self.loss = loss
         | 
| 46 | 
            +
                    self.optimizer = optimizer
         | 
| 47 | 
            +
                    self.input_norm = input_norm
         | 
| 48 | 
            +
                    self.fourier_loss = fourier_loss
         | 
| 49 | 
            +
                    self.fourier_loss_weight = fourier_loss_weight
         | 
| 50 | 
            +
                    self.lr_scheduler = lr_scheduler
         | 
| 51 | 
            +
                    self.use_MAE_weight_init = use_MAE_weight_init
         | 
| 52 | 
            +
                    self.crop_size = crop_size
         | 
| 53 | 
            +
                    self.mask_fourier_loss = mask_fourier_loss
         | 
| 54 | 
            +
                    self.return_channelwise_embeddings = return_channelwise_embeddings
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            class MAEModel(PreTrainedModel):
         | 
| 58 | 
            +
                config_class = MAEConfig
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                # Loss metrics
         | 
| 61 | 
            +
                TOTAL_LOSS = "loss"
         | 
| 62 | 
            +
                RECON_LOSS = "reconstruction_loss"
         | 
| 63 | 
            +
                FOURIER_LOSS = "fourier_loss"
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                def __init__(self, config: MAEConfig):
         | 
| 66 | 
            +
                    super().__init__(config)
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    self.mask_ratio = config.mask_ratio
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    # Could use Hydra to instantiate instead
         | 
| 71 | 
            +
                    self.encoder = MAEEncoder(
         | 
| 72 | 
            +
                        vit_backbone=sincos_positional_encoding_vit(
         | 
| 73 | 
            +
                            vit_backbone=vit_small_patch16_256(global_pool="avg")
         | 
| 74 | 
            +
                        ),
         | 
| 75 | 
            +
                        max_in_chans=11,  # upper limit on number of input channels
         | 
| 76 | 
            +
                        channel_agnostic=True,
         | 
| 77 | 
            +
                    )
         | 
| 78 | 
            +
                    self.decoder = CAMAEDecoder(
         | 
| 79 | 
            +
                        depth=8,
         | 
| 80 | 
            +
                        embed_dim=512,
         | 
| 81 | 
            +
                        mlp_ratio=4,
         | 
| 82 | 
            +
                        norm_layer=nn.LayerNorm,
         | 
| 83 | 
            +
                        num_heads=16,
         | 
| 84 | 
            +
                        num_modalities=6,
         | 
| 85 | 
            +
                        qkv_bias=True,
         | 
| 86 | 
            +
                        tokens_per_modality=256,
         | 
| 87 | 
            +
                    )
         | 
| 88 | 
            +
                    self.input_norm = torch.nn.Sequential(
         | 
| 89 | 
            +
                        Normalizer(),
         | 
| 90 | 
            +
                        nn.InstanceNorm2d(None, affine=False, track_running_stats=False),
         | 
| 91 | 
            +
                    )
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                    self.fourier_loss_weight = config.fourier_loss_weight
         | 
| 94 | 
            +
                    self.mask_fourier_loss = config.mask_fourier_loss
         | 
| 95 | 
            +
                    self.return_channelwise_embeddings = config.return_channelwise_embeddings
         | 
| 96 | 
            +
                    self.tokens_per_channel = 256  # hardcode the number of tokens per channel since we are patch16 crop 256
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    # loss stuff
         | 
| 99 | 
            +
                    self.loss = torch.nn.MSELoss(reduction="none")
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                    self.fourier_loss = FourierLoss(num_multimodal_modalities=6)
         | 
| 102 | 
            +
                    if self.fourier_loss_weight > 0 and self.fourier_loss is None:
         | 
| 103 | 
            +
                        raise ValueError(
         | 
| 104 | 
            +
                            "FourierLoss weight is activated but no fourier_loss was defined in constructor"
         | 
| 105 | 
            +
                        )
         | 
| 106 | 
            +
                    elif self.fourier_loss_weight >= 1:
         | 
| 107 | 
            +
                        raise ValueError(
         | 
| 108 | 
            +
                            "FourierLoss weight is too large to do mixing factor, weight should be < 1"
         | 
| 109 | 
            +
                        )
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    self.patch_size = int(self.encoder.vit_backbone.patch_embed.patch_size[0])
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    # projection layer between the encoder and decoder
         | 
| 114 | 
            +
                    self.encoder_decoder_proj = nn.Linear(
         | 
| 115 | 
            +
                        self.encoder.embed_dim, self.decoder.embed_dim, bias=True
         | 
| 116 | 
            +
                    )
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                    self.decoder_pred = nn.Linear(
         | 
| 119 | 
            +
                        self.decoder.embed_dim,
         | 
| 120 | 
            +
                        self.patch_size**2
         | 
| 121 | 
            +
                        * (1 if self.encoder.channel_agnostic else self.in_chans),
         | 
| 122 | 
            +
                        bias=True,
         | 
| 123 | 
            +
                    )  # linear layer from decoder embedding to input dims
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                    # overwrite decoder pos embeddings based on encoder params
         | 
| 126 | 
            +
                    self.decoder.pos_embeddings = generate_2d_sincos_pos_embeddings(  # type: ignore[assignment]
         | 
| 127 | 
            +
                        self.decoder.embed_dim,
         | 
| 128 | 
            +
                        length=self.encoder.vit_backbone.patch_embed.grid_size[0],
         | 
| 129 | 
            +
                        use_class_token=self.encoder.vit_backbone.cls_token is not None,
         | 
| 130 | 
            +
                        num_modality=(
         | 
| 131 | 
            +
                            self.decoder.num_modalities if self.encoder.channel_agnostic else 1
         | 
| 132 | 
            +
                        ),
         | 
| 133 | 
            +
                    )
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    if config.use_MAE_weight_init:
         | 
| 136 | 
            +
                        w = self.encoder.vit_backbone.patch_embed.proj.weight.data
         | 
| 137 | 
            +
                        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                        torch.nn.init.normal_(self.encoder.vit_backbone.cls_token, std=0.02)
         | 
| 140 | 
            +
                        torch.nn.init.normal_(self.decoder.mask_token, std=0.02)
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                        self.apply(self._MAE_init_weights)
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                def setup(self, stage: str) -> None:
         | 
| 145 | 
            +
                    super().setup(stage)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                def _MAE_init_weights(self, m):
         | 
| 148 | 
            +
                    if isinstance(m, nn.Linear):
         | 
| 149 | 
            +
                        torch.nn.init.xavier_uniform_(m.weight)
         | 
| 150 | 
            +
                        if isinstance(m, nn.Linear) and m.bias is not None:
         | 
| 151 | 
            +
                            nn.init.constant_(m.bias, 0)
         | 
| 152 | 
            +
                    elif isinstance(m, nn.LayerNorm):
         | 
| 153 | 
            +
                        nn.init.constant_(m.bias, 0)
         | 
| 154 | 
            +
                        nn.init.constant_(m.weight, 1.0)
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                @staticmethod
         | 
| 157 | 
            +
                def decode_to_reconstruction(
         | 
| 158 | 
            +
                    encoder_latent: torch.Tensor,
         | 
| 159 | 
            +
                    ind_restore: torch.Tensor,
         | 
| 160 | 
            +
                    proj: torch.nn.Module,
         | 
| 161 | 
            +
                    decoder: MAEDecoder | CAMAEDecoder,
         | 
| 162 | 
            +
                    pred: torch.nn.Module,
         | 
| 163 | 
            +
                ) -> torch.Tensor:
         | 
| 164 | 
            +
                    """Feed forward the encoder latent through the decoders necessary projections and transformations."""
         | 
| 165 | 
            +
                    decoder_latent_projection = proj(
         | 
| 166 | 
            +
                        encoder_latent
         | 
| 167 | 
            +
                    )  # projection from encoder.embed_dim to decoder.embed_dim
         | 
| 168 | 
            +
                    decoder_tokens = decoder.forward_masked(
         | 
| 169 | 
            +
                        decoder_latent_projection, ind_restore
         | 
| 170 | 
            +
                    )  # decoder.embed_dim output
         | 
| 171 | 
            +
                    predicted_reconstruction = pred(
         | 
| 172 | 
            +
                        decoder_tokens
         | 
| 173 | 
            +
                    )  # linear projection to input dim
         | 
| 174 | 
            +
                    return predicted_reconstruction[:, 1:, :]  # drop class token
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                def forward(
         | 
| 177 | 
            +
                    self, imgs: torch.Tensor, constant_noise: Union[torch.Tensor, None] = None
         | 
| 178 | 
            +
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         | 
| 179 | 
            +
                    imgs = self.input_norm(imgs)
         | 
| 180 | 
            +
                    latent, mask, ind_restore = self.encoder.forward_masked(
         | 
| 181 | 
            +
                        imgs, self.mask_ratio, constant_noise
         | 
| 182 | 
            +
                    )  # encoder blocks
         | 
| 183 | 
            +
                    reconstruction = self.decode_to_reconstruction(
         | 
| 184 | 
            +
                        latent,
         | 
| 185 | 
            +
                        ind_restore,
         | 
| 186 | 
            +
                        self.encoder_decoder_proj,
         | 
| 187 | 
            +
                        self.decoder,
         | 
| 188 | 
            +
                        self.decoder_pred,
         | 
| 189 | 
            +
                    )
         | 
| 190 | 
            +
                    return latent, reconstruction, mask
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                def compute_MAE_loss(
         | 
| 193 | 
            +
                    self,
         | 
| 194 | 
            +
                    reconstruction: torch.Tensor,
         | 
| 195 | 
            +
                    img: torch.Tensor,
         | 
| 196 | 
            +
                    mask: torch.Tensor,
         | 
| 197 | 
            +
                ) -> Tuple[torch.Tensor, Dict[str, float]]:
         | 
| 198 | 
            +
                    """Computes final loss and returns specific values of component losses for metric reporting."""
         | 
| 199 | 
            +
                    loss_dict = {}
         | 
| 200 | 
            +
                    img = self.input_norm(img)
         | 
| 201 | 
            +
                    target_flattened = flatten_images(
         | 
| 202 | 
            +
                        img,
         | 
| 203 | 
            +
                        patch_size=self.patch_size,
         | 
| 204 | 
            +
                        channel_agnostic=self.encoder.channel_agnostic,
         | 
| 205 | 
            +
                    )
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                    loss: torch.Tensor = self.loss(
         | 
| 208 | 
            +
                        reconstruction, target_flattened
         | 
| 209 | 
            +
                    )  # should be with MSE or MAE (L1) with reduction='none'
         | 
| 210 | 
            +
                    loss = loss.mean(
         | 
| 211 | 
            +
                        dim=-1
         | 
| 212 | 
            +
                    )  # average over embedding dim -> mean loss per patch (N,L)
         | 
| 213 | 
            +
                    loss = (loss * mask).sum() / mask.sum()  # mean loss on masked patches only
         | 
| 214 | 
            +
                    loss_dict[self.RECON_LOSS] = loss.item()
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    # compute fourier loss
         | 
| 217 | 
            +
                    if self.fourier_loss_weight > 0:
         | 
| 218 | 
            +
                        floss: torch.Tensor = self.fourier_loss(reconstruction, target_flattened)
         | 
| 219 | 
            +
                        if not self.mask_fourier_loss:
         | 
| 220 | 
            +
                            floss = floss.mean()
         | 
| 221 | 
            +
                        else:
         | 
| 222 | 
            +
                            floss = floss.mean(dim=-1)
         | 
| 223 | 
            +
                            floss = (floss * mask).sum() / mask.sum()
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                        loss_dict[self.FOURIER_LOSS] = floss.item()
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                    # here we use a mixing factor to keep the loss magnitude appropriate with fourier
         | 
| 228 | 
            +
                    if self.fourier_loss_weight > 0:
         | 
| 229 | 
            +
                        loss = (1 - self.fourier_loss_weight) * loss + (
         | 
| 230 | 
            +
                            self.fourier_loss_weight * floss
         | 
| 231 | 
            +
                        )
         | 
| 232 | 
            +
                    return loss, loss_dict
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                def training_step(self, batch: TensorDict, batch_idx: int) -> TensorDict:
         | 
| 235 | 
            +
                    img = batch["pixels"]
         | 
| 236 | 
            +
                    latent, reconstruction, mask = self(img.clone())
         | 
| 237 | 
            +
                    full_loss, loss_dict = self.compute_MAE_loss(reconstruction, img.float(), mask)
         | 
| 238 | 
            +
                    return {
         | 
| 239 | 
            +
                        "loss": full_loss,
         | 
| 240 | 
            +
                        **loss_dict,  # type: ignore[dict-item]
         | 
| 241 | 
            +
                    }
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                def validation_step(self, batch: TensorDict, batch_idx: int) -> TensorDict:
         | 
| 244 | 
            +
                    return self.training_step(batch, batch_idx)
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                def update_metrics(self, outputs: TensorDict, batch: TensorDict) -> None:
         | 
| 247 | 
            +
                    self.metrics["lr"].update(value=self.lr_scheduler.get_last_lr())
         | 
| 248 | 
            +
                    for key, value in outputs.items():
         | 
| 249 | 
            +
                        if key.endswith("loss"):
         | 
| 250 | 
            +
                            self.metrics[key].update(value)
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                def on_validation_batch_end(  # type: ignore[override]
         | 
| 253 | 
            +
                    self,
         | 
| 254 | 
            +
                    outputs: TensorDict,
         | 
| 255 | 
            +
                    batch: TensorDict,
         | 
| 256 | 
            +
                    batch_idx: int,
         | 
| 257 | 
            +
                    dataloader_idx: int = 0,
         | 
| 258 | 
            +
                ) -> None:
         | 
| 259 | 
            +
                    super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                def predict(self, imgs: torch.Tensor) -> torch.Tensor:
         | 
| 262 | 
            +
                    imgs = self.input_norm(imgs)
         | 
| 263 | 
            +
                    X = self.encoder.vit_backbone.forward_features(
         | 
| 264 | 
            +
                        imgs
         | 
| 265 | 
            +
                    )  # 3d tensor N x num_tokens x dim
         | 
| 266 | 
            +
                    if self.return_channelwise_embeddings:
         | 
| 267 | 
            +
                        N, _, d = X.shape
         | 
| 268 | 
            +
                        num_channels = imgs.shape[1]
         | 
| 269 | 
            +
                        X_reshaped = X[:, 1:, :].view(N, num_channels, self.tokens_per_channel, d)
         | 
| 270 | 
            +
                        pooled_segments = X_reshaped.mean(
         | 
| 271 | 
            +
                            dim=2
         | 
| 272 | 
            +
                        )  # Resulting shape: (N, num_channels, d)
         | 
| 273 | 
            +
                        latent = pooled_segments.view(N, num_channels * d).contiguous()
         | 
| 274 | 
            +
                    else:
         | 
| 275 | 
            +
                        latent = X[:, 1:, :].mean(dim=1)  # 1 + 256 * C tokens
         | 
| 276 | 
            +
                    return latent
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                def save_pretrained(self, save_directory: str, **kwargs):
         | 
| 279 | 
            +
                    filename = kwargs.pop("filename", "model.safetensors")
         | 
| 280 | 
            +
                    modelpath = f"{save_directory}/{filename}"
         | 
| 281 | 
            +
                    self.config.save_pretrained(save_directory)
         | 
| 282 | 
            +
                    torch.save({"state_dict": self.state_dict()}, modelpath)
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                @classmethod
         | 
| 285 | 
            +
                def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         | 
| 286 | 
            +
                    filename = kwargs.pop("filename", "model.safetensors")
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                    modelpath = f"{pretrained_model_name_or_path}/{filename}"
         | 
| 289 | 
            +
                    config = MAEConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
         | 
| 290 | 
            +
                    state_dict = torch.load(modelpath, map_location="cpu")
         | 
| 291 | 
            +
                    model = cls(config, *model_args, **kwargs)
         | 
| 292 | 
            +
                    model.load_state_dict(state_dict["state_dict"])
         | 
| 293 | 
            +
                    return model
         | 
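Putting the pieces above together, a minimal hypothetical round trip with `MAEModel` might look like the following; the directories are placeholders, and the 384-dimensional outputs correspond to the ViT-S encoder configured above:

```python
import torch

from huggingface_mae import MAEModel

model_dir = "models/phenom_beta_huggingface"  # placeholder; see the README for the download step

model = MAEModel.from_pretrained(model_dir)
model.eval()

imgs = torch.randint(0, 255, (2, 6, 256, 256), dtype=torch.uint8)
with torch.no_grad():
    pooled = model.predict(imgs)        # shape (2, 384): mean over all patch tokens
    model.return_channelwise_embeddings = True
    per_channel = model.predict(imgs)   # shape (2, 6 * 384): one 384-dim vector per channel

# save_pretrained writes config.json plus a torch-saved state dict (default filename "model.safetensors")
model.save_pretrained("models/phenom_beta_copy")
```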
    	
        loss.py
    ADDED
    
    | @@ -0,0 +1,59 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # © Recursion Pharmaceuticals 2024
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import torch.nn as nn
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            class FourierLoss(nn.Module):
         | 
| 7 | 
            +
                def __init__(
         | 
| 8 | 
            +
                    self,
         | 
| 9 | 
            +
                    use_l1_loss: bool = True,
         | 
| 10 | 
            +
                    num_multimodal_modalities: int = 1,  # set to 1 for vanilla MAE, 6 for channel-agnostic MAE
         | 
| 11 | 
            +
                ) -> None:
         | 
| 12 | 
            +
                    """
         | 
| 13 | 
            +
                    Fourier transform loss is only sound when using L1 or L2 loss to compare the frequency domains
         | 
| 14 | 
            +
                    between the images / their radial histograms.
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                    We will always set `reduction="none"` and enforce that the computation of any reductions from the
         | 
| 17 | 
            +
                    output of this loss be managed by the model under question.
         | 
| 18 | 
            +
                    """
         | 
| 19 | 
            +
                    super().__init__()
         | 
| 20 | 
            +
                    self.loss = (
         | 
| 21 | 
            +
                        nn.L1Loss(reduction="none") if use_l1_loss else nn.MSELoss(reduction="none")
         | 
| 22 | 
            +
                    )
         | 
| 23 | 
            +
                    self.num_modalities = num_multimodal_modalities
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         | 
| 26 | 
            +
                    # input = reconstructed image, target = original image
         | 
| 27 | 
            +
                    # flattened images from MAE are (B, H*W, C), so, here we convert to B x C x H x W (note we assume H == W)
         | 
| 28 | 
            +
                    flattened_images = len(input.shape) == len(target.shape) == 3
         | 
| 29 | 
            +
                    if flattened_images:
         | 
| 30 | 
            +
                        B, H_W, C = input.shape
         | 
| 31 | 
            +
                        H_W = H_W // self.num_modalities
         | 
| 32 | 
            +
                        four_d_shape = (B, C * self.num_modalities, int(H_W**0.5), int(H_W**0.5))
         | 
| 33 | 
            +
                        input = input.view(*four_d_shape)
         | 
| 34 | 
            +
                        target = target.view(*four_d_shape)
         | 
| 35 | 
            +
                    else:
         | 
| 36 | 
            +
                        B, C, h, w = input.shape
         | 
| 37 | 
            +
                        H_W = h * w
         | 
| 38 | 
            +
             | 
| 39 | 
            +
        if len(input.shape) != 4 or len(target.shape) != 4:
         | 
| 40 | 
            +
                        raise ValueError(
         | 
| 41 | 
            +
                            f"Invalid input shape: got {input.shape} and {target.shape}."
         | 
| 42 | 
            +
                        )
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                    fft_reconstructed = torch.fft.fft2(input)
         | 
| 45 | 
            +
                    fft_original = torch.fft.fft2(target)
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    magnitude_reconstructed = torch.abs(fft_reconstructed)
         | 
| 48 | 
            +
                    magnitude_original = torch.abs(fft_original)
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                    loss_tensor: torch.Tensor = self.loss(
         | 
| 51 | 
            +
                        magnitude_reconstructed, magnitude_original
         | 
| 52 | 
            +
                    )
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                    if (
         | 
| 55 | 
            +
                        flattened_images and not self.num_bins
         | 
| 56 | 
            +
                    ):  # then output loss should be reshaped
         | 
| 57 | 
            +
                        loss_tensor = loss_tensor.reshape(B, H_W * self.num_modalities, C)
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                    return loss_tensor
         | 
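For intuition about the shapes in the channel-agnostic setup used here (6 modalities, 16x16 patches, 256 tokens per channel): the loss accepts the flattened MAE layout and returns an unreduced tensor of the same shape, leaving masking and averaging to the model (see `compute_MAE_loss` above). A small illustrative example:

```python
import torch

from loss import FourierLoss

floss = FourierLoss(num_multimodal_modalities=6)

# Flattened MAE layout: (batch, 6 channels x 256 patch tokens, 16*16 pixel values per patch)
reconstruction = torch.randn(2, 6 * 256, 16 * 16)
target = torch.randn(2, 6 * 256, 16 * 16)

per_element = floss(reconstruction, target)
print(per_element.shape)  # torch.Size([2, 1536, 256]); the MAE model then masks and averages this
```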
    	
        mae_modules.py
    ADDED
    
    | @@ -0,0 +1,273 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # © Recursion Pharmaceuticals 2024
         | 
| 2 | 
            +
            from functools import partial
         | 
| 3 | 
            +
            from typing import Tuple, Union
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            import torch.nn as nn
         | 
| 7 | 
            +
            from timm.models.helpers import checkpoint_seq
         | 
| 8 | 
            +
            from timm.models.vision_transformer import Block, Mlp, VisionTransformer
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from masking import transformer_random_masking
         | 
| 11 | 
            +
            from vit import channel_agnostic_vit
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            # If interested in training new MAEs, combine an encoder and decoder into a new module, and you should
         | 
| 14 | 
            +
            # leverage the flattening and unflattening utilities as needed from mae_utils.py.
         | 
| 15 | 
            +
            # Be sure to use an encoder-decoder Linear projection layer to match encoder dims with decoder dimensions.
         | 
| 16 | 
            +
            # As described in the paper, images are self-standardized at the start.
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            class SelfStandardize(nn.Module):
         | 
| 20 | 
            +
                def __init__(self) -> None:
         | 
| 21 | 
            +
                    super().__init__()
         | 
| 22 | 
            +
                    self.self_standardize = nn.LazyInstanceNorm2d(
         | 
| 23 | 
            +
                        affine=False, track_running_stats=False
         | 
| 24 | 
            +
                    )
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                def forward(self, pixels: torch.Tensor) -> torch.Tensor:
         | 
| 27 | 
            +
                    x = pixels.float() / 255.0
         | 
| 28 | 
            +
                    return self.self_standardize(x)
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            class MAEEncoder(nn.Module):
         | 
| 32 | 
            +
                def __init__(
         | 
| 33 | 
            +
                    self,
         | 
| 34 | 
            +
                    vit_backbone: VisionTransformer,
         | 
| 35 | 
            +
                    max_in_chans: int = 6,
         | 
| 36 | 
            +
                    channel_agnostic: bool = False,
         | 
| 37 | 
            +
                ) -> None:
         | 
| 38 | 
            +
                    super().__init__()
         | 
| 39 | 
            +
                    if channel_agnostic:
         | 
| 40 | 
            +
                        self.vit_backbone = channel_agnostic_vit(
         | 
| 41 | 
            +
                            vit_backbone, max_in_chans=max_in_chans
         | 
| 42 | 
            +
                        )
         | 
| 43 | 
            +
                    else:
         | 
| 44 | 
            +
                        self.vit_backbone = vit_backbone
         | 
| 45 | 
            +
                    self.max_in_chans = max_in_chans
         | 
| 46 | 
            +
                    self.channel_agnostic = channel_agnostic
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                @property
         | 
| 49 | 
            +
                def embed_dim(self) -> int:
         | 
| 50 | 
            +
                    return int(self.vit_backbone.embed_dim)
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 53 | 
            +
                    x = self.vit_backbone.forward_features(x)
         | 
| 54 | 
            +
                    x = self.vit_backbone.forward_head(x)
         | 
| 55 | 
            +
                    return x  # type: ignore[no-any-return]
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                def forward_masked(
         | 
| 58 | 
            +
                    self,
         | 
| 59 | 
            +
                    x: torch.Tensor,
         | 
| 60 | 
            +
                    mask_ratio: float,
         | 
| 61 | 
            +
                    constant_noise: Union[torch.Tensor, None] = None,
         | 
| 62 | 
            +
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         | 
| 63 | 
            +
                    x = self.vit_backbone.patch_embed(x)
         | 
| 64 | 
            +
                    x = self.vit_backbone._pos_embed(x)  # adds class token
         | 
| 65 | 
            +
                    x_ = x[:, 1:, :]  # no class token
         | 
| 66 | 
            +
                    x_, mask, ind_restore = transformer_random_masking(
         | 
| 67 | 
            +
                        x_, mask_ratio, constant_noise
         | 
| 68 | 
            +
                    )
         | 
| 69 | 
            +
                    x = torch.cat([x[:, :1, :], x_], dim=1)  # add class token
         | 
| 70 | 
            +
                    x = self.vit_backbone.norm_pre(x)
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    if self.vit_backbone.grad_checkpointing and not torch.jit.is_scripting():
         | 
| 73 | 
            +
                        x = checkpoint_seq(self.vit_backbone.blocks, x)
         | 
| 74 | 
            +
                    else:
         | 
| 75 | 
            +
                        x = self.vit_backbone.blocks(x)
         | 
| 76 | 
            +
                    x = self.vit_backbone.norm(x)
         | 
| 77 | 
            +
                    return x, mask, ind_restore
         | 
| 78 | 
            +
             | 
| 79 | 
            +
class MAEDecoder(nn.Module):
    def __init__(
        self,
        embed_dim: int = 512,
        depth: int = 8,
        num_heads: int = 16,
        mlp_ratio: float = 4,
        qkv_bias: bool = True,
        norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),  # type: ignore[assignment]
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.pos_embeddings = None  # to be overwritten by MAE class
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.blocks = nn.Sequential(
            *[
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pos_embeddings
        x = self.blocks(x)
        x = self.norm(x)
        return x  # type: ignore[no-any-return]

    def forward_masked(
        self, x: torch.Tensor, ind_restore: torch.Tensor
    ) -> torch.Tensor:
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ind_restore.shape[1] + 1 - x.shape[1], 1
        )
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # remove class token
        x_ = torch.gather(
            x_, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # add class token

        x = x + self.pos_embeddings
        x = self.blocks(x)
        x = self.norm(x)
        return x  # type: ignore[no-any-return]


class CrossAttention(nn.Module):
    def __init__(
        self, embed_dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = embed_dim // num_heads
        self.scale = head_dim**-0.5

        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, context):
        B, N, C = x.shape
        _, M, _ = context.shape

        q = (
            self.q(x)
            .reshape(B, N, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        kv = (
            self.kv(context)
            .reshape(B, M, 2, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class CAMAEDecoder(nn.Module):
    def __init__(
        self,
        num_modalities: int = 6,
        tokens_per_modality: int = 256,
        embed_dim: int = 256,
        depth: int = 2,
        num_heads: int = 16,
        mlp_ratio: float = 4,
        qkv_bias: bool = True,
        norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),  # type: ignore[assignment]
    ) -> None:
        super().__init__()
        self.num_modalities = num_modalities
        self.tokens_per_modality = tokens_per_modality
        self.embed_dim = embed_dim
        self.pos_embeddings = None  # to be overwritten by MAE class
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.placeholder = nn.Parameter(
            torch.zeros(1, 1, embed_dim), requires_grad=False
        )
        self.modality_tokens = nn.ParameterList(
            [
                nn.Parameter(torch.zeros(1, 1, self.embed_dim))
                for modality in range(self.num_modalities)
            ]
        )

        self.cross_attention = CrossAttention(embed_dim=self.embed_dim)
        self.mlp = Mlp(self.embed_dim, hidden_features=int(self.embed_dim * mlp_ratio))

        self.decoders = nn.ModuleList(
            [
                nn.Sequential(
                    *[
                        Block(
                            embed_dim,
                            num_heads,
                            mlp_ratio,
                            qkv_bias=qkv_bias,
                            norm_layer=norm_layer,
                        )
                        for i in range(depth)
                    ]
                )
                for modality in range(self.num_modalities)
            ]
        )
        # self.norm = norm_layer(embed_dim)  # we decided to drop the last layer norm
        self.context_norm = norm_layer(embed_dim)
        self.query_norm = norm_layer(embed_dim)
        self.out_norm = norm_layer(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_m_s = []

        modality_tokens_concat = torch.cat(
            [
                self.placeholder,
            ]  # placeholder for class token
            + [
                m_t.repeat(1, self.tokens_per_modality, 1)
                for m_t in self.modality_tokens
            ],
            dim=1,
        )

        x = (
            x + self.pos_embeddings + modality_tokens_concat
        )  # add pos and tiled modality tokens
        x_ = x[:, 1:, :]  # no class token
        for m, decoder in enumerate(
            self.decoders
        ):  # iterate through modalities and decoders
            x_m = x_[
                :, m * self.tokens_per_modality : (m + 1) * self.tokens_per_modality, :
            ]
            x_m = self.cross_attention(self.query_norm(x_m), self.context_norm(x_))
            x_m = x_m + self.mlp(self.out_norm(x_m))
            x_m = decoder(x_m)
            x_m_s.append(x_m)
        x_m_s = torch.cat(x_m_s, dim=1)  # concat all tokens
        # x_m_s = self.norm(x_m_s)  # we decided to drop the last layer norm
        x_m_s = torch.cat([x[:, :1, :], x_m_s], dim=1)  # add back class token

        return x_m_s

    def forward_masked(
        self, x: torch.Tensor, ind_restore: torch.Tensor
    ) -> torch.Tensor:
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ind_restore.shape[1] + 1 - x.shape[1], 1
        )
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # remove class token
        x_ = torch.gather(
            x_, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # add class token
        x = self.forward(x)
        return x
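For reference, a minimal sketch of how the decoder path above is exercised together with transformer_random_masking from masking.py. The toy sizes (batch 2, 16 patch tokens, embed_dim 64) and the zero positional embedding are illustrative assumptions; in the real model the MAE wrapper sets pos_embeddings to fixed sin/cos encodings.

import torch
from mae_modules import MAEDecoder
from masking import transformer_random_masking

decoder = MAEDecoder(embed_dim=64, depth=1, num_heads=4)
# normally overwritten by the MAE class with sin/cos embeddings (1 cls + 16 patch tokens here)
decoder.pos_embeddings = torch.zeros(1, 17, 64)

patch_tokens = torch.randn(2, 16, 64)                   # encoder patch tokens, no cls token
kept, mask, ind_restore = transformer_random_masking(patch_tokens, mask_ratio=0.75)
cls_token = torch.zeros(2, 1, 64)
encoded_visible = torch.cat([cls_token, kept], dim=1)   # (2, 1 + 4, 64) visible tokens only

decoded = decoder.forward_masked(encoded_visible, ind_restore)
print(decoded.shape)  # torch.Size([2, 17, 64]): mask tokens re-inserted and unshuffled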
    	
    mae_utils.py
ADDED

@@ -0,0 +1,70 @@
# © Recursion Pharmaceuticals 2024
import math

import torch


def flatten_images(
    img: torch.Tensor, patch_size: int, channel_agnostic: bool = False
) -> torch.Tensor:
    """
    Flattens 2D images into tokens with the same pixel values

    Parameters
    ----------
    img : input image tensor (N, C, H, W)

    Returns
    -------
    flattened_img: flattened image tensor (N, L, patch_size**2 * C)
    """

    if (img.shape[2] != img.shape[3]) or (img.shape[2] % patch_size != 0):
        raise ValueError("image H must equal image W and be divisible by patch_size")
    in_chans = img.shape[1]

    h = w = int(img.shape[2] // patch_size)
    x = img.reshape(shape=(img.shape[0], in_chans, h, patch_size, w, patch_size))

    if channel_agnostic:
        x = torch.permute(x, (0, 1, 2, 4, 3, 5))  # NCHPWQ -> NCHWPQ
        x = x.reshape(shape=(img.shape[0], in_chans * h * w, int(patch_size**2)))
    else:
        x = torch.permute(x, (0, 2, 4, 3, 5, 1))  # NCHPWQ -> NHWPQC
        x = x.reshape(shape=(img.shape[0], h * w, int(patch_size**2 * in_chans)))
    return x


def unflatten_tokens(
    tokens: torch.Tensor,
    patch_size: int,
    num_modalities: int = 1,
    channel_agnostic: bool = False,
) -> torch.Tensor:
    """
    Unflattens tokens (N,L,patch_size**2 * C) into image tensor (N,C,H,W) with the pixel values

    Parameters
    ----------
    tokens : input token tensor (N,L,patch_size**2 * C)

    Returns
    -------
    img: image tensor (N,C,H,W)
    """
    if num_modalities > 1 and not channel_agnostic:
        raise ValueError("Multiple modalities requires channel agnostic unflattening.")

    h = w = int(math.sqrt(tokens.shape[1] // num_modalities))
    if h * w != (tokens.shape[1] // num_modalities):
        raise ValueError("sqrt of number of tokens not integer")

    if channel_agnostic:
        x = tokens.reshape(shape=(tokens.shape[0], -1, h, w, patch_size, patch_size))
        x = torch.permute(x, (0, 1, 2, 4, 3, 5))  # NCHWPQ -> NCHPWQ
    else:
        x = tokens.reshape(shape=(tokens.shape[0], h, w, patch_size, patch_size, -1))
        x = torch.permute(x, (0, 5, 1, 3, 2, 4))  # NHWPQC -> NCHPWQ
    img = x.reshape(shape=(x.shape[0], -1, h * patch_size, h * patch_size))

    return img
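The two helpers are exact inverses of each other; a quick round-trip sketch on random data (the 6-channel 256x256 size is just an example):

import torch
from mae_utils import flatten_images, unflatten_tokens

img = torch.rand(2, 6, 256, 256)  # (N, C, H, W)
tokens = flatten_images(img, patch_size=16, channel_agnostic=True)
print(tokens.shape)  # torch.Size([2, 1536, 256]): 6 * 16 * 16 tokens of 16*16 pixels each

recovered = unflatten_tokens(tokens, patch_size=16, num_modalities=6, channel_agnostic=True)
print(torch.allclose(recovered, img))  # True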
    	
    masking.py
ADDED

@@ -0,0 +1,51 @@
# © Recursion Pharmaceuticals 2024
from typing import Tuple, Union

import torch


def transformer_random_masking(
    x: torch.Tensor, mask_ratio: float, constant_noise: Union[torch.Tensor, None] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Randomly mask patches per sample

    Parameters
    ----------
    x : token tensor (N, L, D)
    mask_ratio : float - ratio of tokens to mask
    constant_noise : None by default; if provided, should be a tensor of shape (N, L) used to produce consistent masks

    Returns
    -------
    x_masked : sub-sampled version of x (N, int(L * (1 - mask_ratio)), D)
    mask : binary mask indicating masked tokens (1 where masked) (N, L)
    ind_restore : indices that restore the original token order, needed by the decoder
    """

    N, L, D = x.shape  # batch, length, dim
    len_keep = int(L * (1 - mask_ratio))

    # use random noise to generate batch based random masks
    if constant_noise is not None:
        noise = constant_noise
    else:
        noise = torch.rand(N, L, device=x.device)

    shuffled_tokens = torch.argsort(noise, dim=1)  # shuffled index
    ind_restore = torch.argsort(shuffled_tokens, dim=1)  # unshuffled index

    # get masked input
    tokens_to_keep = shuffled_tokens[:, :len_keep]  # keep the first len_keep indices
    x_masked = torch.gather(
        x, dim=1, index=tokens_to_keep.unsqueeze(-1).repeat(1, 1, D)
    )

    # get binary mask used for loss masking: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(
        mask, dim=1, index=ind_restore
    )  # unshuffle to get the binary mask

    return x_masked, mask, ind_restore
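A small sketch of the returned values and of the constant_noise option (toy shapes, chosen only for illustration):

import torch
from masking import transformer_random_masking

x = torch.randn(2, 16, 8)  # (N, L, D)
x_masked, mask, ind_restore = transformer_random_masking(x, mask_ratio=0.75)
print(x_masked.shape)   # torch.Size([2, 4, 8]): 25% of tokens kept
print(mask.sum(dim=1))  # tensor([12., 12.]): 75% of tokens flagged as masked per sample

# the same noise tensor reproduces the same mask, e.g. for paired views
noise = torch.rand(2, 16)
_, mask_a, _ = transformer_random_masking(x, 0.75, constant_noise=noise)
_, mask_b, _ = transformer_random_masking(x, 0.75, constant_noise=noise)
print(torch.equal(mask_a, mask_b))  # True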
    	
    normalizer.py
ADDED

@@ -0,0 +1,7 @@
import torch


class Normalizer(torch.nn.Module):
    def forward(self, pixels: torch.Tensor) -> torch.Tensor:
        pixels = pixels.float()
        return pixels / 255.0
    	
    pyproject.toml
ADDED

@@ -0,0 +1,34 @@
[build-system]
requires = ["setuptools >= 61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "maes_microscopy_project"
version = "0.1.0"
authors = [
    {name = "kian-kd", email = "[email protected]"},
    {name = "Laksh47", email = "[email protected]"},
]
requires-python = ">=3.10.4"

dependencies = [
    "huggingface-hub",
    "timm",
    "torch>=2.3",
    "torchmetrics",
    "torchvision",
    "tqdm",
    "transformers",
    "xformers",
    "zarr",
    "pytorch-lightning>=2.1",
    "matplotlib",
    "scikit-image",
    "ipykernel",
    "isort",
    "ruff",
    "pytest",
]

[tool.setuptools]
py-modules = []
    	
    sample/AA41_s1_1.jp2
ADDED

    sample/AA41_s1_2.jp2
ADDED

    sample/AA41_s1_3.jp2
ADDED

    sample/AA41_s1_4.jp2
ADDED

    sample/AA41_s1_5.jp2
ADDED

    sample/AA41_s1_6.jp2
ADDED
    	
    test_huggingface_mae.py
ADDED

@@ -0,0 +1,32 @@
import pytest
import torch

from huggingface_mae import MAEModel

huggingface_phenombeta_model_dir = "."
# huggingface_modelpath = "recursionpharma/test-pb-model"


@pytest.fixture
def huggingface_model():
    # Make sure you have the model/config downloaded from https://huggingface.co/recursionpharma/test-pb-model to this directory
    # huggingface-cli download recursionpharma/test-pb-model --local-dir=.
    huggingface_model = MAEModel.from_pretrained(huggingface_phenombeta_model_dir)
    huggingface_model.eval()
    return huggingface_model


@pytest.mark.parametrize("C", [1, 4, 6, 11])
@pytest.mark.parametrize("return_channelwise_embeddings", [True, False])
def test_model_predict(huggingface_model, C, return_channelwise_embeddings):
    example_input_array = torch.randint(
        low=0,
        high=255,
        size=(2, C, 256, 256),
        dtype=torch.uint8,
        device=huggingface_model.device,
    )
    huggingface_model.return_channelwise_embeddings = return_channelwise_embeddings
    embeddings = huggingface_model.predict(example_input_array)
    expected_output_dim = 384 * C if return_channelwise_embeddings else 384
    assert embeddings.shape == (2, expected_output_dim)
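The same predict path can be run outside pytest; a minimal sketch, assuming the checkpoint and config have been downloaded to the working directory as noted in the fixture above:

import torch
from huggingface_mae import MAEModel

# huggingface-cli download recursionpharma/test-pb-model --local-dir=.
model = MAEModel.from_pretrained(".")
model.eval()

imgs = torch.randint(0, 255, (2, 6, 256, 256), dtype=torch.uint8, device=model.device)
with torch.no_grad():
    embeddings = model.predict(imgs)
print(embeddings.shape)  # (2, 384), or (2, 384 * C) with return_channelwise_embeddings=True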
    	
    vit.py
ADDED

@@ -0,0 +1,309 @@
# © Recursion Pharmaceuticals 2024
import timm.models.vision_transformer as vit
import torch


def generate_2d_sincos_pos_embeddings(
    embedding_dim: int,
    length: int,
    scale: float = 10000.0,
    use_class_token: bool = True,
    num_modality: int = 1,
) -> torch.nn.Parameter:
    """
    Generate 2-dimensional sin/cos positional embeddings

    Parameters
    ----------
    embedding_dim : int
        embedding dimension used in vit
    length : int
        number of tokens along height or width of image after patching (assuming square)
    scale : float
        scale for sin/cos functions
    use_class_token : bool
        True - prepend a zero vector to be added to the class_token, False - no vector added
    num_modality : int
        number of channel modalities; the positional encoding is tiled across modalities
        (default 1 - a single modality is assumed)

    Returns
    -------
    positional_encoding : torch.Tensor
        positional encoding to add to vit patch encodings
        [num_modality*length*length, embedding_dim] or [1+num_modality*length*length, embedding_dim]
        (w/o or w/ cls_token)
    """

    linear_positions = torch.arange(length, dtype=torch.float32)
    height_mesh, width_mesh = torch.meshgrid(
        linear_positions, linear_positions, indexing="ij"
    )
    positional_dim = embedding_dim // 4  # accommodate h and w x cos and sin embeddings
    positional_weights = (
        torch.arange(positional_dim, dtype=torch.float32) / positional_dim
    )
    positional_weights = 1.0 / (scale**positional_weights)

    height_weights = torch.outer(height_mesh.flatten(), positional_weights)
    width_weights = torch.outer(width_mesh.flatten(), positional_weights)

    positional_encoding = torch.cat(
        [
            torch.sin(height_weights),
            torch.cos(height_weights),
            torch.sin(width_weights),
            torch.cos(width_weights),
        ],
        dim=1,
    )[None, :, :]

    # repeat positional encoding for multiple channel modalities
    positional_encoding = positional_encoding.repeat(1, num_modality, 1)

    if use_class_token:
        class_token = torch.zeros([1, 1, embedding_dim], dtype=torch.float32)
        positional_encoding = torch.cat([class_token, positional_encoding], dim=1)

    positional_encoding = torch.nn.Parameter(positional_encoding, requires_grad=False)

    return positional_encoding


class ChannelAgnosticPatchEmbed(vit.PatchEmbed):  # type: ignore[misc]
    def __init__(
        self,
        img_size: int,
        patch_size: int,
        embed_dim: int,
        bias: bool = True,
    ) -> None:
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=1,  # in_chans is used by self.proj, which we override anyway
            embed_dim=embed_dim,
            norm_layer=None,
            flatten=False,
            bias=bias,
        )
        # channel-agnostic MAE has a single projection for all chans
        self.proj = torch.nn.Conv2d(
            1, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_chans = x.shape[1]
        x = torch.stack(
            [self.proj(x[:, i : i + 1]) for i in range(in_chans)], dim=2
        )  # single projection applied to each channel
        x = x.flatten(2).transpose(1, 2)  # BCMHW -> BNC
        return x


class ChannelAgnosticViT(vit.VisionTransformer):  # type: ignore[misc]
    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        # rewrite of https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L586
        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))

        # TODO: upgrade timm to get access to register tokens
        # if self.vit_backbone.reg_token is not None:
        #     to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))

        # MAIN DIFFERENCE from timm: we DYNAMICALLY ADD POS EMBEDDINGS based on the shape of the inputs
        # this supports having CA-MAEs actually be channel-agnostic at inference time
        if self.no_embed_class:
            x = x + self.pos_embed[:, : x.shape[1]]
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + self.pos_embed[:, : x.shape[1]]
        return self.pos_drop(x)  # type: ignore[no-any-return]


def channel_agnostic_vit(
    vit_backbone: vit.VisionTransformer, max_in_chans: int
) -> vit.VisionTransformer:
    # replace patch embedding with channel-agnostic version
    vit_backbone.patch_embed = ChannelAgnosticPatchEmbed(
        img_size=vit_backbone.patch_embed.img_size[0],
        patch_size=vit_backbone.patch_embed.patch_size[0],
        embed_dim=vit_backbone.embed_dim,
    )

    # replace positional embedding with channel-agnostic version
    vit_backbone.pos_embed = generate_2d_sincos_pos_embeddings(
        embedding_dim=vit_backbone.embed_dim,
        length=vit_backbone.patch_embed.grid_size[0],
        use_class_token=vit_backbone.cls_token is not None,
        num_modality=max_in_chans,
    )

    # change the class to be ChannelAgnostic so that it actually uses the new _pos_embed
    vit_backbone.__class__ = ChannelAgnosticViT
    return vit_backbone


def sincos_positional_encoding_vit(
    vit_backbone: vit.VisionTransformer, scale: float = 10000.0
) -> vit.VisionTransformer:
    """Attaches no-grad sin-cos positional embeddings to a pre-constructed ViT backbone model.

    Parameters
    ----------
    vit_backbone : timm.models.vision_transformer.VisionTransformer
        the constructed vision transformer from timm
    scale : float (default 10000.0)
        hyperparameter for sincos positional embeddings, recommend keeping at 10,000

    Returns
    -------
    timm.models.vision_transformer.VisionTransformer
        the same ViT but with fixed no-grad positional encodings to add to vit patch encodings
    """
    # length: number of tokens along height or width of image after patching (assuming square)
    length = (
        vit_backbone.patch_embed.img_size[0] // vit_backbone.patch_embed.patch_size[0]
    )
    pos_embeddings = generate_2d_sincos_pos_embeddings(
        vit_backbone.embed_dim,
        length=length,
        scale=scale,
        use_class_token=vit_backbone.cls_token is not None,
    )
    # note, if the model had weight_init == 'skip', this might get overwritten
    vit_backbone.pos_embed = pos_embeddings
    return vit_backbone


def vit_small_patch16_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_small_patch16_224(**default_kwargs)


def vit_small_patch32_512(**kwargs):
    default_kwargs = dict(
        img_size=512,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_small_patch32_384(**default_kwargs)


def vit_base_patch8_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_base_patch8_224(**default_kwargs)


def vit_base_patch16_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_base_patch16_224(**default_kwargs)


def vit_base_patch32_512(**kwargs):
    default_kwargs = dict(
        img_size=512,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_base_patch32_384(**default_kwargs)


def vit_large_patch8_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        patch_size=8,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        drop_path_rate=0.3,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.VisionTransformer(**default_kwargs)


def vit_large_patch16_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.3,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_large_patch16_384(**default_kwargs)
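A minimal sketch of wiring these helpers together; the small backbone, pretrained=False, and the input sizes are illustrative assumptions:

import torch
from vit import channel_agnostic_vit, sincos_positional_encoding_vit, vit_small_patch16_256

backbone = vit_small_patch16_256(pretrained=False)
backbone = sincos_positional_encoding_vit(backbone)        # fixed sin/cos positional embeddings
backbone = channel_agnostic_vit(backbone, max_in_chans=6)  # shared per-channel patch projection

for C in (3, 6):  # channel count can vary at inference time
    tokens = backbone.forward_features(torch.randn(1, C, 256, 256))
    print(C, tokens.shape)  # (1, 1 + C * (256 // 16) ** 2, 384)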
    	
    vit_encoder.py
ADDED

@@ -0,0 +1,61 @@
# © Recursion Pharmaceuticals 2024
from typing import Dict

import timm.models.vision_transformer as vit
import torch


def build_imagenet_baselines() -> Dict[str, torch.jit.ScriptModule]:
    """This returns the prepped imagenet encoders from timm, not bad for microscopy data."""
    vit_backbones = [
        _make_vit(vit.vit_small_patch16_384),
        _make_vit(vit.vit_base_patch16_384),
        _make_vit(vit.vit_base_patch8_224),
        _make_vit(vit.vit_large_patch16_384),
    ]
    model_names = [
        "vit_small_patch16_384",
        "vit_base_patch16_384",
        "vit_base_patch8_224",
        "vit_large_patch16_384",
    ]
    imagenet_encoders = list(map(_make_torchscripted_encoder, vit_backbones))
    return {name: model for name, model in zip(model_names, imagenet_encoders)}


def _make_torchscripted_encoder(vit_backbone) -> torch.jit.ScriptModule:
    dummy_input = torch.testing.make_tensor(
        (2, 6, 256, 256),
        low=0,
        high=255,
        dtype=torch.uint8,
        device=torch.device("cpu"),
    )
    encoder = torch.nn.Sequential(
        Normalizer(),
        torch.nn.LazyInstanceNorm2d(
            affine=False, track_running_stats=False
        ),  # this module performs self-standardization, very important
        vit_backbone,
    ).to(device="cpu")
    _ = encoder(dummy_input)  # get those lazy modules built
    return torch.jit.freeze(torch.jit.script(encoder.eval()))


def _make_vit(constructor):
    return constructor(
        pretrained=True,  # download imagenet weights
        img_size=256,  # 256x256 crops
        in_chans=6,  # we expect 6-channel microscopy images
        num_classes=0,
        fc_norm=None,
        class_token=True,
        global_pool="avg",  # minimal perf diff btwn "cls" and "avg"
    )


class Normalizer(torch.nn.Module):
    def forward(self, pixels: torch.Tensor) -> torch.Tensor:
        pixels = pixels.float()
        pixels /= 255.0
        return pixels
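A quick sketch of what the Normalizer + LazyInstanceNorm2d preprocessing in _make_torchscripted_encoder does to the inputs (random uint8 images stand in for real crops):

import torch
from vit_encoder import Normalizer

preprocess = torch.nn.Sequential(
    Normalizer(),  # uint8 -> float in [0, 1]
    torch.nn.LazyInstanceNorm2d(affine=False, track_running_stats=False),
)

pixels = torch.randint(0, 255, (2, 6, 256, 256), dtype=torch.uint8)
out = preprocess(pixels)
print(out.mean(dim=(2, 3)).abs().max())  # ~0: each channel of each image is zero-mean
print(out.std(dim=(2, 3)).mean())        # ~1: and unit variance (self-standardization)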
            +
                    return pixels
         | 

