편명장/님/(myeongjang.pyeon) committed on
Commit
287a683
·
1 Parent(s): 4083dcc

initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # PyPI configuration file
171
+ .pypirc
172
+
LICENSE CHANGED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ EXAONEPath AI Model License Agreement 1.0 - NC
2
+
3
+ This License Agreement (“Agreement”) is entered into between you (“Licensee”) and LG Management Development Institute Co., Ltd. (“Licensor”), governing the use of the EXAONEPath AI Model (“Model”). By downloading, installing, copying, or using the Model, you agree to comply with and be bound by the terms of this Agreement. If you do not agree to all the terms, you must not download, install, copy, or use the Model. This Agreement constitutes a binding legal agreement between the Licensee and Licensor.
4
+
5
+ 1. Definitions
6
+ 1.1 Model: The artificial intelligence model provided by Licensor, which includes any software, algorithms, machine learning models, or related components supplied by Licensor. This definition extends to encompass all updates, enhancements, improvements, bug fixes, patches, or other modifications that may be provided by Licensor from time to time, whether automatically or manually implemented.
7
+ 1.2 Derivatives: Any modifications, alterations, enhancements, improvements, adaptations, or derivative works of the Model created by Licensee or any third party. This includes changes made to the Model's architecture, parameters, data processing methods, or any other aspect of the Model that results in a modification of its functionality or output.
8
+ 1.3 Output: Any data, results, content, predictions, analyses, insights, or other materials generated by the Model or Derivatives, regardless of whether they are in their original form or have been further processed or modified by the Licensee. This includes, but is not limited to, textual or numerical produced directly or indirectly through the use of the Model.
9
+ 1.4 Licensor: LG Management Development Institute Co., Ltd., the owner, developer, and provider of the EXAONEPath AI Model. The Licensor holds all rights, title, and interest in the Model and is responsible for granting licenses to use the Model under the terms specified in this Agreement.
10
+ 1.5 Licensee: The individual, organization, corporation, academic institution, government agency, or other entity using or intending to use the Model under the terms and conditions of this Agreement. The Licensee is responsible for ensuring compliance with the Agreement by all authorized users who access or utilize the Model on behalf of the Licensee.
11
+
12
+ 2. License Grant
13
+ 2.1 Grant of License: Subject to the terms and conditions outlined in this Agreement, the Licensor hereby grants the Licensee a limited, non-exclusive, non-transferable, worldwide, and revocable license to:
14
+ a. Access, download, install, and use the Model solely for research purposes. This includes evaluation, testing, academic research and experimentation.
15
+ b. Publicly disclose research results and findings derived from the use of the Model or Derivatives, including publishing papers or presentations.
16
+ c. Modify the Model and create Derivatives based on the Model, provided that such modifications and Derivatives are used exclusively for research purposes. The Licensee may conduct experiments, perform analyses, and apply custom modifications to the Model to explore its capabilities and performance under various scenarios. If the Model is modified, the modified Model must include "EXAONEPath" at the beginning of its name.
17
+ d. Distribute the Model and Derivatives in each case with a copy of this Agreement.
18
+ 2.2 Scope of License: The license granted herein does not authorize the Licensee to use the Model for any purpose not explicitly permitted under this Agreement. Any use beyond the scope of this license, including any commercial application or external distribution, is strictly prohibited unless explicitly agreed upon in writing by the Licensor.
19
+
20
+ 3. Restrictions
21
+ 3.1 Commercial Use: The Licensee is expressly prohibited from using the Model, Derivatives, or Output for any commercial purposes, including but not limited to, developing or deploying products, services, or applications that generate revenue, whether directly or indirectly. Any commercial exploitation of the Model or its derivatives requires a separate commercial license agreement with the Licensor. Furthermore, the Licensee shall not use the Model, Derivatives or Output to develop or improve other models, except for research purposes, which is explicitly permitted.
22
+ 3.2 Reverse Engineering: The Licensee shall not decompile, disassemble, reverse engineer, or attempt to derive the source code, underlying ideas, algorithms, or structure of the Model, except to the extent that such activities are expressly permitted by applicable law. Any attempt to bypass or circumvent technological protection measures applied to the Model is strictly prohibited.
23
+ 3.3 Unlawful Use: The Licensee shall not use the Model and Derivatives for any illegal, fraudulent, or unauthorized activities, nor for any purpose that violates applicable laws or regulations. This includes but is not limited to the creation, distribution, or dissemination of malicious, deceptive, or unlawful content.
24
+ 3.4 Ethical Use: The Licensee shall ensure that the Model or Derivatives is used in an ethical and responsible manner, adhering to the following guidelines:
25
+ a. The Model and Derivatives shall not be used to generate, propagate, or amplify false, misleading, or harmful information, including fake news, misinformation, or disinformation.
26
+ b. The Model and Derivatives shall not be employed to create, distribute, or promote content that is discriminatory, harassing, defamatory, abusive, or otherwise offensive to individuals or groups based on race, gender, sexual orientation, religion, nationality, or other protected characteristics.
27
+ c. The Model and Derivatives shall not infringe on the rights of others, including intellectual property rights, privacy rights, or any other rights recognized by law. The Licensee shall obtain all necessary permissions and consents before using the Model and Derivatives in a manner that may impact the rights of third parties.
28
+ d. The Model and Derivatives shall not be used in a way that causes harm, whether physical, mental, emotional, or financial, to individuals, organizations, or communities. The Licensee shall take all reasonable measures to prevent misuse or abuse of the Model and Derivatives that could result in harm or injury.
29
+
30
+ 4. Ownership
31
+ 4.1 Intellectual Property: All rights, title, and interest in and to the Model, including any modifications, Derivatives, and associated documentation, are and shall remain the exclusive property of the Licensor. The Licensee acknowledges that this Agreement does not transfer any ownership rights to the Licensee. All trademarks, service marks, and logos associated with the Model are the property of the Licensor.
32
+ 4.2 Output: All output generated by the Model from Licensee Data ("Output") shall be the sole property of the Licensee. Licensor hereby waives any claim of ownership or intellectual property rights to the Output. Licensee is solely responsible for the legality, accuracy, quality, integrity, and use of the Output.
33
+ 4.3 Attribution: In any publication or presentation of results obtained using the Model, the Licensee shall provide appropriate attribution to the Licensor, citing the Model's name and version, along with any relevant documentation or references specified by the Licensor.
34
+
35
+ 5. No Warranty
36
+ 5.1 “As-Is” Basis: The Model, Derivatives, and Output are provided on an “as-is” and “as-available” basis, without any warranties or representations of any kind, whether express, implied, or statutory. The Licensor disclaims all warranties, including but not limited to, implied warranties of merchantability, fitness for a particular purpose, accuracy, reliability, non-infringement, or any warranty arising from the course of dealing or usage of trade.
37
+ 5.2 Performance and Reliability: The Licensor does not warrant or guarantee that the Model, Derivatives or Output will meet the Licensee’s requirements, that the operation of the Model, Derivatives or Output will be uninterrupted or error-free, or that defects in the Model will be corrected. The Licensee acknowledges that the use of the Model, Derivatives or Output is at its own risk and that the Model, Derivatives or Output may contain bugs, errors, or other limitations.
38
+ 5.3 No Endorsement: The Licensor does not endorse, approve, or certify any results, conclusions, or recommendations derived from the use of the Model. The Licensee is solely responsible for evaluating the accuracy, reliability, and suitability of the Model for its intended purposes.
39
+
40
+ 6. Limitation of Liability
41
+ 6.1 No Liability for Damages: To the fullest extent permitted by applicable law, in no event shall the Licensor be liable for any special, incidental, indirect, consequential, exemplary, or punitive damages, including but not limited to, damages for loss of business profits, business interruption, loss of business information, loss of data, or any other pecuniary or non-pecuniary loss arising out of or in connection with the use or inability to use the Model, Derivatives or any Output, even if the Licensor has been advised of the possibility of such damages.
42
+ 6.2 Indemnification: The Licensee agrees to indemnify, defend, and hold harmless the Licensor, its affiliates, officers, directors, employees, and agents from and against any claims, liabilities, damages, losses, costs, or expenses (including reasonable attorneys' fees) arising out of or related to the Licensee's use of the Model, any Derivatives, or any Output, including any violation of this Agreement or applicable laws. This includes, but is not limited to, ensuring compliance with copyright laws, privacy regulations, defamation laws, and any other applicable legal or regulatory requirements.
43
+
44
+ 7. Termination
45
+ 7.1 Termination by Licensor: The Licensor reserves the right to terminate this Agreement and revoke the Licensee’s rights to use the Model at any time, with or without cause, and without prior notice if the Licensee breaches any of the terms or conditions of this Agreement. Termination shall be effective immediately upon notice.
46
+ 7.2 Effect of Termination: Upon termination of this Agreement, the Licensee must immediately cease all use of the Model, Derivatives, and Output and destroy all copies of the Model, Derivatives, and Output in its possession or control, including any backup or archival copies. The Licensee shall certify in writing to the Licensor that such destruction has been completed.
47
+ 7.3 Survival: The provisions of this Agreement that by their nature should survive termination, including but not limited to, Sections 4 (Ownership), 5 (No Warranty), 6 (Limitation of Liability), and this Section 7 (Termination), shall continue to apply after termination.
48
+
49
+ 8. Governing Law
50
+ 8.1 Governing Law: This Agreement shall be governed by and construed in accordance with the laws of the Republic of Korea, without regard to its conflict of laws principles.
51
+ 8.2 Arbitration: Any disputes, controversies, or claims arising out of or relating to this Agreement, including its existence, validity, interpretation, performance, breach, or termination, shall be referred to and finally resolved by arbitration administered by the Korean Commercial Arbitration Board (KCAB) in accordance with the International Arbitration Rules of the Korean Commercial Arbitration Board in force at the time of the commencement of the arbitration. The seat of arbitration shall be Seoul, Republic of Korea. The tribunal shall consist of one arbitrator. The language of the arbitration shall be English.
52
+
53
+ 9. Alterations
54
+ 9.1 Modifications: The Licensor reserves the right to modify or amend this Agreement at any time, in its sole discretion. Any modifications will be effective upon posting the updated Agreement on the Licensor’s website or through other means of communication. The Licensee is responsible for reviewing the Agreement periodically for changes. Continued use of the Model after any modifications have been made constitutes acceptance of the revised Agreement.
55
+ 9.2 Entire Agreement: This Agreement constitutes the entire agreement between the Licensee and Licensor concerning the subject matter hereof and supersedes all prior or contemporaneous oral or written agreements, representations, or understandings. Any terms or conditions of any purchase order or other document submitted by the Licensee in connection with the Model that are in addition to, different from, or inconsistent with the terms and conditions of this Agreement are not binding on the Licensor and are void.
56
+
57
+ By downloading, installing, or using the EXAONEPath AI Model, the Licensee acknowledges that it has read, understood, and agrees to be bound by the terms and conditions of this Agreement.
README.md CHANGED
@@ -1,5 +1,35 @@
1
- ---
2
- license: other
3
- license_name: license
4
- license_link: LICENSE
5
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Description
2
+ MSI classification of CRC tumors using EXAONEPath 1.0.0 - a patch-level foundation model for pathology
3
+
4
+ # Model Overview
5
+ This model serves as a reference for predicting MSI status using CRC (colorectal cancer) tumor images as input. When the model receives an H&E-stained whole slide image as input, it removes artifacts observed in the image and extracts only tissue-related objects. These objects are then reconstructed into a set of tiles with a size of 256 by 256 pixels at an mpp (micron per pixel) of 0.5.
6
+
7
+ The tiles pass through the EXAONEPath v1.0 patch-level foundation model (https://huggingface.co/LGAI-EXAONE/EXAONEPath), which converts them into a set of features. These features are then integrated into a slide-level feature representation through an aggregator. Finally, a linear classifier predicts the MSI status (MSS or MSI-H/L).
8
+
9
+ The model achieves an average performance of AUROC 0.93 on TCGA-COAD + TCGA-READ data and 0.84 on in-house data.
10
+
11
+ This open-source release aims to demonstrate that the combination of EXAONEPath and the aggregator can effectively perform pathological tasks. It is hoped that this source code will serve as an important reference not only for CRC MSI prediction but also as an image-based solution for various disease-related problems, including molecular subtyping, tumor subtyping, and mutation prediction.
12
+
13
+ This open-source code is designed to be compatible with the MONAI framework (https://monai.io/), and users must check the accompanying license before use.
14
+
15
+ ![EXAONEPath](assets/exaonepath_v1.png)
16
+
17
+ # Input and Output Formats
18
+ The input is a `.svs` file containing a whole slide image.
19
+ The output is an array with probabilities for each of the two classes (MSS / MSI(H/L)).
20
+
21
+ # Execute Inference
22
+ The inference can be executed as follows
23
+ 1. Copy your `.svs` files into `samples` directory
24
+ 2. Run inference
25
+ ```
26
+ python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file configs/inference.json configs/logging.conf
27
+ ```
28
+
29
+ ## Contact
30
+ LG AI Research Technical Support: <a href="mailto:[email protected]">[email protected]</a>
31
+
32
+ # License
33
+ Copyright (c) LG AI Research
34
+
35
+ The model is licensed under [EXAONEPath AI Model License Agreement 1.0 - NC](LICENSE).
assets/exaonepath_v1.png ADDED
configs/inference.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ imports:
2
+ - $import os
3
+ - $import glob
4
+ - $import torch
5
+ - $import scripts
6
+ - $import scripts.inference
7
+
8
+ device: '$torch.device("cuda" if torch.cuda.is_available() else "cpu")'
9
+
10
+ model_config:
11
+ _target_: scripts.exaonepath.EXAONEPathV1Downstream
12
+ step_size: 256
13
+ patch_size: 256
14
+ macenko: true
15
+ device: '$@device'
16
+
17
+ model: '$@model_config'
18
+
19
+ input_dir: 'samples'
20
+ input_files: '$sorted(glob.glob(@input_dir+''/*.svs''))'
21
+
22
+ root_dir: '$os.path.dirname(os.path.dirname(scripts.__file__))'
23
+
24
+ inference:
25
+ - [email protected]_state_dict(torch.load(os.path.join(@root_dir, "models/exaonepath_v1.0.0_msi.pt")))
26
+ - $scripts.inference.infer(@model, @input_files)
configs/metadata.json ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "0.1.0",
3
+ "changelog": {
4
+ "0.1.0": "Initial release"
5
+ },
6
+ "monai_version": "1.5",
7
+ "pytorch_version": "2.6.0",
8
+ "numpy_version": "1.26.4",
9
+ "required_packages_version": {
10
+ "torchvision": "0.21.0",
11
+ "openslide-bin": "4.0.0.6",
12
+ "openslide-python": "1.4.1",
13
+ "pandas": "2.2.3",
14
+ "torchstain": "1.4.1",
15
+ "opencv-python-headless": "4.11.0.86",
16
+ "einops": "0.8.1"
17
+ },
18
+ "name": "MSI predictor from WSI using EXAONEPath v1.0",
19
+ "task": "MSI Classification of Colorectal Cancer (CRC) Tumors",
20
+ "description": "MSI classification of CRC tumors using EXAONEPath 1.0.0 - a patch-level foundation model for pathology",
21
+ "authors": "LG AI Research",
22
+ "copyright": "Copyright (c) LG AI Research",
23
+ "data_source": "Whole Slide Images",
24
+ "data_type": "float32",
25
+ "intended_use": "Research only",
26
+ "network_data_format": {
27
+ "inputs": {
28
+ "image": {
29
+ "type": "image",
30
+ "format": "svs",
31
+ "modality": "pathology"
32
+ }
33
+ },
34
+ "outputs": {
35
+ "pred": {
36
+ "type": "probabilities",
37
+ "format": "classes",
38
+ "num_channels": 2,
39
+ "spatial_shape": [2],
40
+ "dtype": "float32",
41
+ "value_range": [0, 1],
42
+ "is_patch_data": false,
43
+ "channel_def": {
44
+ "0": "MSS",
45
+ "1": "MSI"
46
+ }
47
+ }
48
+ }
49
+ }
50
+ }
models/.gitkeep ADDED
File without changes
models/exaonepath_v1.0.0_msi.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c7c3199335922de3c156ede12e2d655b23ec988529be18611981832d6ab8095
3
+ size 625193090
models/macenko_param.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:74a6e178f1d3ea897152e544857eb880045a8ffb2be1e05b1639386083e68435
3
+ size 999
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu124
2
+
3
+ torch==2.6.0+cu124
4
+ torchvision==0.21.0+cu124
5
+ numpy==1.26.4
6
+ opencv-python-headless==4.11.0.86
7
+ openslide-bin==4.0.0.6
8
+ openslide-python==1.4.1
9
+ monai-weekly[ignite,pyyaml]==1.5.dev2506
10
+ pandas==2.2.3
11
+ torchstain==1.4.1
12
+ fire==0.7.0
13
+ einops==0.8.1
scripts/__init__.py ADDED
File without changes
scripts/aggregator.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import wraps
2
+ from math import log, pi
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from einops import rearrange, repeat
7
+ from einops.layers.torch import Rearrange, Reduce
8
+ from torch import einsum, nn
9
+
10
+
11
def exists(val):
    """True iff *val* carries a value (i.e. is not None)."""
    return not (val is None)
13
+
14
+
15
def default(val, d):
    """Return *val* unless it is None, in which case return the fallback *d*."""
    if val is None:
        return d
    return val
17
+
18
+
19
def cache_fn(f):
    """Memoize *f* by an explicit ``key`` keyword argument.

    The wrapper accepts two extra keyword-only controls:

    * ``_cache`` — when False, bypass the cache and call *f* directly.
    * ``key`` — the cache slot under which the result is stored
      (defaults to None).

    Results are keyed solely by ``key``; the positional/keyword arguments
    are NOT part of the cache key, so callers must vary ``key`` themselves.
    """
    memo = dict()

    @wraps(f)
    def cached_fn(*args, _cache=True, key=None, **kwargs):
        if not _cache:
            return f(*args, **kwargs)
        if key not in memo:
            memo[key] = f(*args, **kwargs)
        return memo[key]

    return cached_fn
34
+
35
+
36
def fourier_encode(x, max_freq, num_bands=4):
    """Append Fourier features to the last dimension of *x*.

    Each scalar is expanded into ``2 * num_bands + 1`` channels: the sin and
    cos of the value scaled by ``num_bands`` frequencies linearly spaced in
    ``[1, max_freq / 2]`` (times pi), followed by the raw value itself.
    """
    orig = x.unsqueeze(-1)

    freqs = torch.linspace(
        1.0, max_freq / 2, num_bands, device=orig.device, dtype=orig.dtype
    )
    # Reshape so the frequency axis broadcasts over every leading axis of x.
    freqs = freqs.reshape((1,) * (orig.dim() - 1) + (num_bands,))

    scaled = orig * freqs * pi
    return torch.cat((scaled.sin(), scaled.cos(), orig), dim=-1)
47
+
48
+
49
class PreNorm(nn.Module):
    """Apply LayerNorm to the input (and optionally the context) before *fn*.

    When ``context_dim`` is given, the ``context`` keyword argument passed to
    ``forward`` is normalized with its own LayerNorm before being forwarded
    to *fn* — used for the cross-attention layers below.
    """

    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        # Context normalization only exists when the context width is known.
        self.norm_context = None if context_dim is None else nn.LayerNorm(context_dim)

    def forward(self, x, **kwargs):
        normed = self.norm(x)

        if self.norm_context is not None:
            kwargs["context"] = self.norm_context(kwargs["context"])

        return self.fn(normed, **kwargs)
65
+
66
+
67
class GEGLU(nn.Module):
    """Gated GELU activation.

    Splits the last dimension in half and multiplies one half by the GELU of
    the other, so the output has half the input's channel width.
    """

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return value * F.gelu(gate)
71
+
72
+
73
class FeedForward(nn.Module):
    """Position-wise feed-forward block with a GEGLU gate.

    Projects ``dim -> dim * mult * 2``, halves the width back to
    ``dim * mult`` through the GEGLU gate, then projects back to ``dim``
    with dropout.
    """

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        # The first Linear doubles the hidden width because GEGLU consumes
        # half of its input as gates.
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of *x* (shape preserved)."""
        return self.net(x)
85
+
86
+
87
class Attention(nn.Module):
    """Multi-head (cross-)attention.

    Queries are computed from ``x``; keys and values from ``context``, which
    defaults to ``x`` (plain self-attention). An optional boolean ``mask``
    (True = keep) blanks out context positions before the softmax.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        if context_dim is None:
            context_dim = query_dim
        inner_dim = dim_head * heads

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context=None, mask=None):
        num_heads = self.heads

        if context is None:
            context = x
        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        # Fold the heads into the batch axis: (b, n, h*d) -> (b*h, n, d).
        q, k, v = (
            rearrange(t, "b n (h d) -> (b h) n d", h=num_heads) for t in (q, k, v)
        )

        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        if mask is not None:
            mask = rearrange(mask, "b ... -> b (...)")
            mask = repeat(mask, "b j -> (b h) () j", h=num_heads)
            # Masked-out positions get the most negative finite value so they
            # contribute ~0 probability after the softmax.
            sim.masked_fill_(~mask, -torch.finfo(sim.dtype).max)

        # attention, what we cannot get enough of
        attn = self.dropout(sim.softmax(dim=-1))

        out = einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=num_heads)
        return self.to_out(out)
126
+
127
+
128
class Perceiver(nn.Module):
    def __init__(
        self,
        *,
        num_freq_bands,
        depth,
        max_freq,
        input_channels=3,
        input_axis=2,
        num_latents=512,
        latent_dim=512,
        cross_heads=1,
        latent_heads=8,
        cross_dim_head=64,
        latent_dim_head=64,
        num_classes=1000,
        attn_dropout=0.0,
        ff_dropout=0.0,
        weight_tie_layers=False,
        fourier_encode_data=True,
        self_per_cross_attn=1,
        final_classifier_head=True,
        pool="mean",
        latent_init=None,
    ):
        """The shape of the final attention mechanism will be:
        depth * (cross attention -> self_per_cross_attn * self attention)

        Args:
            num_freq_bands: Number of freq bands, with original value (2 * K + 1)
            depth: Depth of net.
            max_freq: Maximum frequency, hyperparameter depending on how
                fine the data is.
            input_channels: Number of channels for each token of the input.
            input_axis: Number of axes for input data (2 for images, 3 for video)
            num_latents: Number of latents, or induced set points, or centroids.
                Different papers giving it different names.
            latent_dim: Latent dimension.
            cross_heads: Number of heads for cross attention. Paper said 1.
            latent_heads: Number of heads for latent self attention, 8.
            cross_dim_head: Number of dimensions per cross attention head.
            latent_dim_head: Number of dimensions per latent self attention head.
            num_classes: Output number of classes.
            attn_dropout: Attention dropout
            ff_dropout: Feedforward dropout
            weight_tie_layers: Whether to weight tie layers (optional).
            fourier_encode_data: Whether to auto-fourier encode the data, using
                the input_axis given. defaults to True, but can be turned off
                if you are fourier encoding the data yourself.
            self_per_cross_attn: Number of self attention blocks per cross attn.
            final_classifier_head: mean pool and project embeddings to number
                of classes (num_classes) at the end
            pool: Readout used by the classifier head: "cat" (flatten all
                latents), "mlp" (mean pool followed by a 2-layer MLP), or any
                einops reduction string such as "mean" (default).
            latent_init: Optional path to a tensor (loaded via ``torch.load``)
                used to initialize the learned latent array.
        """
        super().__init__()
        self.input_axis = input_axis
        self.max_freq = max_freq
        self.num_freq_bands = num_freq_bands
        self.self_per_cross_attn = self_per_cross_attn

        self.fourier_encode_data = fourier_encode_data
        # Each input axis contributes (2 * num_freq_bands + 1) channels
        # (sin + cos bands plus the raw coordinate), doubled to match the
        # layout produced by fourier_encode on the coords tensor.
        fourier_channels = (
            (input_axis * ((num_freq_bands * 2) + 1)) * 2 if fourier_encode_data else 0
        )
        input_dim = fourier_channels + input_channels

        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
        if latent_init is not None:
            latent_init_feat = torch.load(latent_init)
            if not isinstance(latent_init_feat, torch.Tensor):
                latent_init_feat = torch.Tensor(latent_init_feat)
            if len(latent_init_feat.shape) == 3:
                # Drop a leading batch axis if the saved tensor has one.
                latent_init_feat = latent_init_feat[0]
            with torch.no_grad():
                self.latents.copy_(latent_init_feat)
            print(f"load latent feature: , {latent_init}")

        # Layer factories. They are wrapped with cache_fn below so that, when
        # weight_tie_layers is set, later depths reuse the same module
        # instances (keyed by the ``key`` argument) instead of new ones.
        get_cross_attn = lambda: PreNorm(
            latent_dim,
            Attention(
                latent_dim,
                input_dim,
                heads=cross_heads,
                dim_head=cross_dim_head,
                dropout=attn_dropout,
            ),
            context_dim=input_dim,
        )
        get_cross_ff = lambda: PreNorm(
            latent_dim, FeedForward(latent_dim, dropout=ff_dropout)
        )
        get_latent_attn = lambda: PreNorm(
            latent_dim,
            Attention(
                latent_dim,
                heads=latent_heads,
                dim_head=latent_dim_head,
                dropout=attn_dropout,
            ),
        )
        get_latent_ff = lambda: PreNorm(
            latent_dim, FeedForward(latent_dim, dropout=ff_dropout)
        )

        get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(
            cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff)
        )

        self.layers = nn.ModuleList([])
        for i in range(depth):
            # Only reuse cached modules after the first depth when weight tying.
            should_cache = i > 0 and weight_tie_layers
            cache_args = {"_cache": should_cache}

            self_attns = nn.ModuleList([])

            for block_ind in range(self_per_cross_attn):
                self_attns.append(
                    nn.ModuleList(
                        [
                            get_latent_attn(**cache_args, key=block_ind),
                            get_latent_ff(**cache_args, key=block_ind),
                        ]
                    )
                )
            if self_per_cross_attn == 0:
                # BUGFIX: the original referenced the loop variable
                # ``block_ind`` here, which is unbound when the loop above
                # never runs (self_per_cross_attn == 0), raising NameError at
                # construction. forward() indexes ``self_attns[0]`` in that
                # case, so append one feed-forward layer under a fixed key.
                self_attns.append(get_latent_ff(**cache_args, key=0))

            self.layers.append(
                nn.ModuleList(
                    [
                        get_cross_attn(**cache_args),
                        get_cross_ff(**cache_args),
                        self_attns,
                    ]
                )
            )

        if final_classifier_head:
            if pool == "cat":
                # Flatten every latent into one long vector for the classifier.
                self.to_logits = nn.Sequential(
                    Rearrange("b n d -> b (n d)"),
                    nn.LayerNorm(num_latents * latent_dim),
                    nn.Linear(num_latents * latent_dim, num_classes),
                )
            elif pool == "mlp":
                # Mean-pool latents, then a small 2-layer MLP head.
                self.to_logits = nn.Sequential(
                    Reduce("b n d -> b d", "mean"),
                    nn.LayerNorm(latent_dim),
                    nn.Linear(latent_dim, latent_dim),
                    nn.ReLU(),
                    nn.LayerNorm(latent_dim),
                    nn.Linear(latent_dim, num_classes),
                )
            else:
                # Linear head after the requested einops reduction (e.g. "mean").
                self.to_logits = nn.Sequential(
                    Reduce("b n d -> b d", pool),
                    nn.LayerNorm(latent_dim),
                    nn.Linear(latent_dim, num_classes),
                )

    def forward(self, h, label=None, mask=None, pretrain=False, coords=None):
        """Aggregate patch features ``h`` into slide-level predictions.

        Args:
            h: Features of shape (batch, *axis, channels), where ``axis`` must
                have ``input_axis`` dimensions.
            label: Unused here; kept for interface compatibility with callers.
            mask: Optional boolean mask (True = keep) over context positions,
                forwarded to the cross-attention layers.
            pretrain: When True, return the mean latent embedding instead of
                running the classifier head.
            coords: Patch coordinates; required when ``fourier_encode_data``
                is True (they are Fourier-encoded and concatenated onto ``h``).

        Returns:
            ``(logits, Y_prob, Y_hat)`` — raw logits, softmax probabilities,
            and the top-1 class index — or the pooled latent embedding when
            ``pretrain`` is True. NOTE: when constructed with
            ``final_classifier_head=False``, only ``pretrain=True`` is valid
            because no ``to_logits`` head exists.
        """
        b, *axis, _, device, dtype = *h.shape, h.device, h.dtype
        assert (
            len(axis) == self.input_axis
        ), "input data must have the right number of axis"

        if self.fourier_encode_data:
            # Positional information: Fourier-encode the supplied patch
            # coordinates and append them to the feature channels.
            enc_pos = fourier_encode(coords, self.max_freq, self.num_freq_bands)
            enc_pos = rearrange(enc_pos, "... n d -> ... (n d)")

            h = torch.cat((h, enc_pos), dim=-1)

        # Flatten all spatial axes into a single token axis.
        h = rearrange(h, "b ... d -> b (...) d")

        # One copy of the learned latent array per batch element.
        x = repeat(self.latents, "n d -> b n d", b=b)

        for cross_attn, cross_ff, self_attns in self.layers:
            x = cross_attn(x, context=h, mask=mask) + x
            x = cross_ff(x) + x

            if self.self_per_cross_attn > 0:
                for self_attn, self_ff in self_attns:
                    x = self_attn(x) + x
                    x = self_ff(x) + x
            else:
                # Single feed-forward layer installed in __init__ for this case.
                x = self_attns[0](x) + x

        if pretrain:
            # Embedding mode: mean-pooled latents, no classifier head.
            return x.mean(dim=1)

        # to logits
        logits = self.to_logits(x)
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        Y_prob = F.softmax(logits, dim=1)
        return logits, Y_prob, Y_hat
scripts/constants.py ADDED
@@ -0,0 +1 @@
 
 
1
# Class label order for the classifier output: index i of the model's
# probability vector corresponds to CLASS_NAMES[i] (see scripts/inference.py).
CLASS_NAMES = ["MSS", "MSI"]
scripts/dataset.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from openslide import OpenSlide
4
+ from scripts.preprocessor import MacenkoNormalizer, preprocessor
5
+ from torch.utils.data import Dataset
6
+
7
+
8
class WSIPatchDataset(Dataset):
    """Dataset of preprocessed RGB patches read on the fly from one WSI.

    Each item is the transformed patch image at the idx-th (x, y) level-0
    coordinate, read at ``patch_level`` with side ``patch_size``.
    """

    def __init__(
        self,
        coords,
        wsi_path,
        pretrained=False,
        patch_size=256,
        patch_level=0,
        macenko=True,
    ):
        """
        Args:
            coords: Sequence of (x, y) top-left patch coordinates (level-0 pixels).
            wsi_path: Path to the whole-slide image, opened with OpenSlide.
            pretrained: If True, preprocessing uses ImageNet normalization stats.
            patch_size: Side length of the square patches, in pixels.
            patch_level: Slide pyramid level the patches are read from.
            macenko: If True, apply Macenko stain normalization using the
                bundled reference parameters (<repo root>/models/macenko_param.pt).
        """
        self.pretrained = pretrained
        self.wsi = OpenSlide(wsi_path)
        self.patch_size = patch_size
        self.patch_level = patch_level

        if macenko:
            # Resolve the reference-parameter file relative to this package.
            # (Was os.path.join(__file__) — a no-op single-argument join that
            # also broke when __file__ was a bare relative path; use abspath.)
            normalizer = MacenkoNormalizer(
                target_path=os.path.join(
                    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
                    "models",
                    "macenko_param.pt",
                )
            )
        else:
            normalizer = None

        self.roi_transforms = preprocessor(pretrained=pretrained, normalizer=normalizer)
        self.coords = coords
        self.length = len(self.coords)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        coord = self.coords[idx]
        img = self.wsi.read_region(
            coord, self.patch_level, (self.patch_size, self.patch_size)
        ).convert("RGB")
        img = self.roi_transforms(img)
        return img
scripts/exaonepath.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from scripts.aggregator import Perceiver
5
+ from scripts.dataset import WSIPatchDataset
6
+ from scripts.feature_extractor import vit_base
7
+ from scripts.wsi_utils import extract_tissue_patch_coords
8
+ from torch.utils.data import DataLoader
9
+
10
+
11
class EXAONEPathV1Downstream(nn.Module):
    """Slide-level classifier: tissue patches -> ViT features -> Perceiver.

    ``forward`` takes a path to a WSI (.svs), extracts tissue patch
    coordinates, embeds each patch with a frozen ViT-B feature extractor,
    and aggregates the L2-normalized patch features with a Perceiver head
    to produce class probabilities.
    """

    def __init__(
        self, device: torch.device, step_size=256, patch_size=256, macenko=True
    ):
        """
        Args:
            device: Device used for both the feature extractor and aggregator.
            step_size: Stride between patch top-left corners (level-0 pixels).
            patch_size: Side length of the square patches.
            macenko: Whether patches are Macenko stain-normalized.
        """
        super(EXAONEPathV1Downstream, self).__init__()
        self.step_size = step_size
        self.patch_size = patch_size
        self.macenko = macenko
        self.device = device

        # Frozen ViT-B/16 patch encoder.
        # (Removed the redundant `self.feature_extractor = self.feature_extractor`.)
        self.feature_extractor = vit_base().to(self.device)
        self.feature_extractor.eval()

        self.agg_model = Perceiver(
            input_channels=768,  # matches the ViT-B embedding dimension
            input_axis=1,
            num_freq_bands=6,
            max_freq=10.0,
            depth=6,
            num_latents=256,
            latent_dim=512,
            cross_heads=1,
            latent_heads=8,
            cross_dim_head=64,
            latent_dim_head=64,
            num_classes=2,
            fourier_encode_data=False,
            self_per_cross_attn=2,
            pool="mean",
        )
        self.agg_model.to(self.device)
        self.agg_model.eval()

    @torch.no_grad()
    def forward(self, svs_path: str, feature_extractor_batch_size: int = 8):
        """Return class probabilities (CPU tensor of length 2) for one slide."""
        # Extract tissue patch coordinates
        coords = extract_tissue_patch_coords(
            svs_path, patch_size=self.patch_size, step_size=self.step_size
        )

        # Extract patch-level features
        self.feature_extractor.eval()
        patch_dataset = WSIPatchDataset(
            coords=coords,
            wsi_path=svs_path,
            pretrained=True,
            macenko=self.macenko,
            patch_size=self.patch_size,
        )
        patch_loader = DataLoader(
            dataset=patch_dataset,
            batch_size=feature_extractor_batch_size,
            num_workers=(
                feature_extractor_batch_size * 2 if self.device.type == "cuda" else 0
            ),
            pin_memory=self.device.type == "cuda",
        )
        features_list = []
        for count, patches in enumerate(patch_loader):
            print(
                f"batch {count+1}/{len(patch_loader)}, {count * feature_extractor_batch_size} patches processed",
                end="\r",
            )
            patches = patches.to(self.device, non_blocking=True)

            feature = self.feature_extractor(patches)  # [B, 768] CLS embeddings
            feature /= feature.norm(dim=-1, keepdim=True)  # use normalized feature
            feature = feature.to("cpu", non_blocking=True)
            features_list.append(feature)
        print("")
        print("Feature extraction finished")

        features = torch.cat(features_list)

        # Aggregate the bag of patch features into a slide-level prediction
        self.agg_model.eval()
        logits, Y_prob, Y_hat = self.agg_model(features[None].to(self.device))
        probs = Y_prob[0].cpu()

        return probs
scripts/feature_extractor.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ from functools import partial
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
class DropPath(nn.Module):
    """Stochastic depth: randomly zero entire residual branches per sample."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegate to the functional implementation; only active in training.
        return _drop_path(x, self.drop_prob, self.training)
18
+
19
+
20
class Mlp(nn.Module):
    """Transformer feed-forward block: Linear -> activation -> Linear, with shared dropout."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        # Unspecified widths default to the input width.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # The same dropout module follows both the activation and the output
        # projection, matching the standard ViT MLP.
        return self.drop(self.fc2(self.drop(self.act(self.fc1(x)))))
44
+
45
+
46
class Attention(nn.Module):
    """Multi-head self-attention that returns both the output and the attention map."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Default scaling is 1/sqrt(head_dim) unless explicitly overridden.
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        heads = self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, C // heads)
        qkv = self.qkv(x).reshape(B, N, 3, heads, C // heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.attn_drop(scores.softmax(dim=-1))

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj_drop(self.proj(out))
        return out, attn
83
+
84
+
85
class Block(nn.Module):
    """Pre-norm transformer encoder block: attention and MLP, each with a residual."""

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # Identity when drop_path is zero, so the residual path is a no-op.
        self.drop_path = nn.Identity() if drop_path <= 0.0 else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            # Expose the raw attention map (used for visualization).
            return attn
        x = x + self.drop_path(y)
        return x + self.drop_path(self.mlp(self.norm2(x)))
126
+
127
+
128
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and linearly embed them."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # Patch count per side, squared.
        self.num_patches = (img_size // patch_size) * (img_size // patch_size)

        # A conv with stride == kernel size is a per-patch linear projection.
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # (B, C, H, W) -> (B, embed_dim, H/ps, W/ps) -> (B, num_patches, embed_dim)
        embedded = self.proj(x)
        return embedded.flatten(2).transpose(1, 2)
146
+
147
+
148
class VisionTransformer(nn.Module):
    """DINO-style Vision Transformer.

    ``forward`` returns the final [CLS] token embedding of shape
    (B, embed_dim); with the default ``num_classes=0`` the classifier head
    is an Identity, so the model acts as a feature extractor.
    """

    def __init__(
        self,
        img_size=[224],
        patch_size=16,
        in_chans=3,
        num_classes=0,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        **kwargs
    ):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size[0],
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        # Learnable [CLS] token and positional embeddings (+1 slot for [CLS]).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = (
            nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

        _trunc_normal_(self.pos_embed, std=0.02)
        _trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal linear weights, zero biases, unit LayerNorm scale.
        if isinstance(m, nn.Linear):
            _trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        """Bicubically resample the patch positional embeddings when the
        input resolution differs from the one used at initialization."""
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(
                1, int(math.sqrt(N)), int(math.sqrt(N)), dim
            ).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            # size=(int(w0), int(h0)),
            mode="bicubic",
        )
        assert (
            int(w0) == patch_pos_embed.shape[-2]
            and int(h0) == patch_pos_embed.shape[-1]
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x):
        """Patchify, prepend the [CLS] token, add positions, apply dropout."""
        B, nc, w, h = x.shape
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward(self, x):
        """Return the final normalized [CLS] embedding, shape (B, embed_dim)."""
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]

    def get_last_selfattention(self, x):
        """Return the attention map of the final block (for visualization)."""
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        """Return the normalized token outputs of the last ``n`` blocks."""
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output
288
+
289
+
290
def vit_base(patch_size=16, **kwargs):
    """Construct a ViT-Base backbone (768-dim, 12 blocks, 12 heads).

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
302
+
303
+
304
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
305
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
306
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
307
+ def norm_cdf(x):
308
+ # Computes standard normal cumulative distribution function
309
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
310
+
311
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
312
+ warnings.warn(
313
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
314
+ "The distribution of values may be incorrect.",
315
+ stacklevel=2,
316
+ )
317
+
318
+ with torch.no_grad():
319
+ # Values are generated by using a truncated uniform distribution and
320
+ # then using the inverse CDF for the normal distribution.
321
+ # Get upper and lower cdf values
322
+ l = norm_cdf((a - mean) / std)
323
+ u = norm_cdf((b - mean) / std)
324
+
325
+ # Uniformly fill tensor with values from [l, u], then translate to
326
+ # [2l-1, 2u-1].
327
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
328
+
329
+ # Use inverse cdf transform for normal distribution to get truncated
330
+ # standard normal
331
+ tensor.erfinv_()
332
+
333
+ # Transform to proper mean, std
334
+ tensor.mul_(std * math.sqrt(2.0))
335
+ tensor.add_(mean)
336
+
337
+ # Clamp to ensure it's in the proper range
338
+ tensor.clamp_(min=a, max=b)
339
+ return tensor
340
+
341
+
342
def _trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
    """Truncated-normal initializer; thin wrapper over the no-grad worker."""
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
345
+
346
+
347
+ def _drop_path(x, drop_prob: float = 0.0, training: bool = False):
348
+ if drop_prob == 0.0 or not training:
349
+ return x
350
+ keep_prob = 1 - drop_prob
351
+ shape = (x.shape[0],) + (1,) * (
352
+ x.ndim - 1
353
+ ) # work with diff dim tensors, not just 2D ConvNets
354
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
355
+ random_tensor.floor_() # binarize
356
+ output = x.div(keep_prob) * random_tensor
357
+ return output
scripts/inference.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from scripts.constants import CLASS_NAMES
2
+
3
+
4
def infer(model, input_files):
    """Run the model on each input file and print per-class probabilities."""
    for path in input_files:
        print("Processing", path, "...")
        probs = model(path)
        # One "<class>: <prob>" entry per class, joined with " / ".
        entries = [
            f"{name}: {probs[i].item():.4f}" for i, name in enumerate(CLASS_NAMES)
        ]
        print("Result -- " + " / ".join(entries) + "\n")
scripts/preprocessor.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import torch
5
+ from PIL import Image
6
+ from torchstain.base.normalizers.he_normalizer import HENormalizer
7
+ from torchstain.torch.utils import cov, percentile
8
+ from torchvision import transforms
9
+ from torchvision.transforms.functional import to_pil_image
10
+
11
+
12
def preprocessor(pretrained=False, normalizer=None):
    """Build the patch preprocessing pipeline.

    Args:
        pretrained: If True, normalize with ImageNet statistics; otherwise
            use 0.5 mean/std for every channel.
        normalizer: Optional stain normalizer applied to the PIL image before
            tensor conversion (e.g. MacenkoNormalizer). If None, an identity
            transform is used in its place.

    Returns:
        A torchvision ``transforms.Compose`` mapping a PIL image to a
        normalized (3, 224, 224) float tensor.
    """
    if pretrained:
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    else:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)

    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            # Identity check with `is None` (was `== None`).
            transforms.Lambda(lambda x: x) if normalizer is None else normalizer,
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )

    return preprocess
31
+
32
+
33
+ """
34
+ Source code ported from: https://github.com/schaugf/HEnorm_python
35
+ Original implementation: https://github.com/mitkovetta/staining-normalization
36
+ """
37
+
38
+
39
class TorchMacenkoNormalizer(HENormalizer):
    """Torch implementation of Macenko H&E stain normalization.

    Reference:
        A method for normalizing histology slides for quantitative analysis.
        M. Macenko et al., ISBI 2009
    """

    def __init__(self):
        super().__init__()

        # Reference stain matrix (rows: RGB, columns: [H, E]) and reference
        # maximum stain concentrations; replaced by fit() or loaded externally.
        self.HERef = torch.tensor(
            [[0.5626, 0.2159], [0.7201, 0.8012], [0.4062, 0.5581]]
        )
        self.maxCRef = torch.tensor([1.9705, 1.0308])

        # Avoid using deprecated torch.lstsq (since 1.9.0)
        self.updated_lstsq = hasattr(torch.linalg, "lstsq")

    def __convert_rgb2od(self, I, Io, beta):
        # (C, H, W) -> (H, W, C), then flattened to one row per pixel below.
        I = I.permute(1, 2, 0)

        # calculate optical density; the +1 guards against log(0)
        OD = -torch.log((I.reshape((-1, I.shape[-1])).float() + 1) / Io)

        # remove transparent pixels (optical density below beta in any channel)
        ODhat = OD[~torch.any(OD < beta, dim=1)]

        return OD, ODhat

    def __find_HE(self, ODhat, eigvecs, alpha):
        # project on the plane spanned by the eigenvectors corresponding to the two
        # largest eigenvalues
        That = torch.matmul(ODhat, eigvecs)
        phi = torch.atan2(That[:, 1], That[:, 0])

        # robust angular extremes at the alpha / (100 - alpha) percentiles
        minPhi = percentile(phi, alpha)
        maxPhi = percentile(phi, 100 - alpha)

        vMin = torch.matmul(
            eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))
        ).unsqueeze(1)
        vMax = torch.matmul(
            eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))
        ).unsqueeze(1)

        # a heuristic to make the vector corresponding to hematoxylin first and the
        # one corresponding to eosin second
        HE = torch.where(
            vMin[0] > vMax[0],
            torch.cat((vMin, vMax), dim=1),
            torch.cat((vMax, vMin), dim=1),
        )

        return HE

    def __find_concentration(self, OD, HE):
        # rows correspond to channels (RGB), columns to OD values
        Y = OD.T

        # determine concentrations of the individual stains
        if not self.updated_lstsq:
            # legacy API (pre-1.9): torch.lstsq(B, A)
            return torch.lstsq(Y, HE)[0][:2]

        return torch.linalg.lstsq(HE, Y)[0]

    def __compute_matrices(self, I, Io, alpha, beta):
        OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)

        # compute eigenvectors of the OD covariance and keep the two largest
        # (torch.linalg.eigh returns eigenvalues in ascending order)
        _, eigvecs = torch.linalg.eigh(cov(ODhat.T))
        eigvecs = eigvecs[:, [1, 2]]

        HE = self.__find_HE(ODhat, eigvecs, alpha)

        C = self.__find_concentration(OD, HE)
        # 99th percentile as a robust per-stain maximum concentration
        maxC = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])

        return HE, C, maxC

    def fit(self, I, Io=240, alpha=1, beta=0.15):
        """Estimate and store the reference stain matrix and max concentrations from image I."""
        HE, _, maxC = self.__compute_matrices(I, Io, alpha, beta)

        self.HERef = HE
        self.maxCRef = maxC

    def normalize(
        self, I, Io=240, alpha=1, beta=0.15, stains=True, form="chw", dtype="int"
    ):
        """Normalize staining appearence of H&E stained images

        Example use:
            see test.py

        Input:
            I: RGB input image: tensor of shape [C, H, W] and type uint8
            Io: (optional) transmitted light intensity
            alpha: percentile
            beta: transparency threshold
            stains: if true, return also H & E components

        Output:
            Inorm: normalized image (float tensor in [0, 1], shape [C, H, W])
            H: hematoxylin image (None when stains is False)
            E: eosin image (None when stains is False)

        NOTE(review): the `form` and `dtype` parameters are accepted but never
        used in this implementation — confirm before relying on them.

        Reference:
            A method for normalizing histology slides for quantitative analysis. M.
            Macenko et al., ISBI 2009
        """

        c, h, w = I.shape

        HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta)

        # normalize stain concentrations
        C *= (self.maxCRef / maxC).unsqueeze(-1)

        # recreate the image using reference mixing matrix
        Inorm = Io * torch.exp(-torch.matmul(self.HERef, C))
        Inorm = torch.clip(Inorm, 0, 255)

        Inorm = Inorm.reshape(c, h, w).float() / 255.0
        Inorm = torch.clip(Inorm, 0.0, 1.0)

        H, E = None, None

        if stains:
            # reconstruct single-stain images from each reference stain vector
            H = torch.mul(
                Io,
                torch.exp(
                    torch.matmul(-self.HERef[:, 0].unsqueeze(-1), C[0, :].unsqueeze(0))
                ),
            )
            H[H > 255] = 255
            H = H.T.reshape(h, w, c).int()

            E = torch.mul(
                Io,
                torch.exp(
                    torch.matmul(-self.HERef[:, 1].unsqueeze(-1), C[1, :].unsqueeze(0))
                ),
            )
            E[E > 255] = 255
            E = E.T.reshape(h, w, c).int()

        return Inorm, H, E
180
+
181
+
182
class MacenkoNormalizer:
    """Callable PIL-image -> PIL-image Macenko stain normalizer.

    Loads (or fits) reference stain parameters once at construction, then
    normalizes each image passed to ``__call__``, falling back to the
    original image when normalization fails or produces NaNs.
    """

    def __init__(self, target_path=None, prob=1):
        # PIL image -> float tensor scaled to [0, 255], the input range
        # expected by TorchMacenkoNormalizer.
        self.transform_before_macenko = transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * 255)]
        )
        self.normalizer = TorchMacenkoNormalizer()

        ext = os.path.splitext(target_path)[1].lower()
        if ext in [".jpg", ".jpeg", ".png"]:
            # Fit the reference stain parameters from a target image.
            target = Image.open(target_path)
            self.normalizer.fit(self.transform_before_macenko(target))
        elif ext in [".pt"]:
            # Load precomputed reference parameters.
            target = torch.load(target_path)
            self.normalizer.HERef = target["HERef"]
            self.normalizer.maxCRef = target["maxCRef"]

        else:
            raise ValueError(f"Invalid extension: {ext}")
        # NOTE(review): `prob` is stored but never read in __call__ —
        # presumably an intended apply-probability; confirm before relying on it.
        self.prob = prob

    def __call__(self, image):
        t_to_transform = self.transform_before_macenko(image)
        try:
            image_macenko, _, _ = self.normalizer.normalize(
                I=t_to_transform, stains=False, form="chw", dtype="float"
            )
            # NaNs indicate a degenerate fit; fall back to the raw image.
            if torch.any(torch.isnan(image_macenko)):
                return image
            else:
                image_macenko = to_pil_image(image_macenko)
                return image_macenko
        except Exception as e:
            # Known numerical failures (percentile/eigendecomposition on
            # near-empty tissue) are silenced; anything else is logged.
            # In all failure cases the original image is returned (best-effort).
            if "kthvalue()" in str(e) or "linalg.eigh" in str(e):
                pass
            else:
                print(str(e))
            return image
scripts/wsi_utils.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from openslide import OpenSlide
5
+
6
+
7
def extract_tissue_patch_coords(
    wsi_path: str,
    patch_size: int = 256,
    step_size: int = 256,
    downsample_threshold: float = 64,
    threshold: int = 8,
    max_val: int = 255,
    median_kernel: int = 7,
    close_size: int = 4,
    min_effective_area_factor: float = 100,  # multiplied by (ref_area)^2
    ref_area: int = 512,
    min_hole_area_factor: float = 16,  # multiplied by (ref_area)^2
    max_n_holes: int = 8,
) -> list:
    """
    Compute level-0 coordinates of patches whose centers fall within tissue.

    Process:
        1. Open the WSI.
        2. Select a segmentation level and compute a binary mask.
        3. Find contours and holes from the mask and filter them using effective area criteria.
        4. Scale the external contours and holes to full resolution.
        5. Slide a window over the full-resolution image and keep a patch if its center is in tissue.

    Raises:
        ValueError: if no valid tissue contours or no patches are found.

    Returns:
        A list of (x, y) tuples — level-0 pixel coordinates of each kept
        patch's top-left corner (NOT a tensor of patch pixel data).
    """
    slide = OpenSlide(wsi_path)
    full_width, full_height = slide.level_dimensions[0]

    seg_level, scale = select_segmentation_level(slide, downsample_threshold)
    binary_mask = compute_segmentation_mask(
        slide, seg_level, threshold, max_val, median_kernel, close_size
    )

    # Compute thresholds for effective area and hole area
    # (factors are defined relative to ref_area^2 at level 0 and converted
    #  to mask-level pixels via the downsample scale).
    effective_area_thresh = min_effective_area_factor * (
        ref_area**2 / (scale[0] * scale[1])
    )
    hole_area_thresh = min_hole_area_factor * (ref_area**2 / (scale[0] * scale[1]))

    ext_contours, holes_list = filter_contours_and_holes(
        binary_mask, effective_area_thresh, hole_area_thresh, max_n_holes
    )
    if not ext_contours:
        raise ValueError("No valid tissue contours found.")

    tissue_contours = scale_contours(ext_contours, scale)
    scaled_holes = [scale_contours(holes, scale) for holes in holes_list]

    coords = []
    # Slide a patch-sized window over level 0; keep a window only when its
    # center point lies inside tissue (and outside any hole).
    for y in range(0, full_height - patch_size + 1, step_size):
        for x in range(0, full_width - patch_size + 1, step_size):
            center_x = x + patch_size // 2
            center_y = y + patch_size // 2
            if not point_in_tissue(center_x, center_y, tissue_contours, scaled_holes):
                continue
            coords.append((x, y))

    if not coords:
        raise ValueError("No available patches")
    return coords
69
+
70
+
71
def select_segmentation_level(slide: OpenSlide, downsample_threshold: float = 64):
    """
    Select a segmentation level whose downsample factor is at least the threshold.

    Returns:
        level (int): Chosen level index.
        scale (tuple): Downsample factors (sx, sy) for that level.
    """
    chosen = slide.get_best_level_for_downsample(downsample_threshold)
    factor = slide.level_downsamples[chosen]
    # Normalize a scalar downsample to an (sx, sy) pair.
    if isinstance(factor, (tuple, list)):
        return chosen, factor
    return chosen, (factor, factor)
84
+
85
+
86
def compute_segmentation_mask(
    slide: OpenSlide,
    level: int,
    threshold: int = 20,
    max_val: int = 255,
    median_kernel: int = 7,
    close_size: int = 4,
):
    """
    Compute a binary tissue mask at the given pyramid level.

    Pipeline: read the level as RGB -> HSV saturation channel -> median
    blur -> fixed binary threshold -> optional morphological closing.

    Returns:
        binary (ndarray): Binary mask image.
    """
    region = slide.read_region((0, 0), level, slide.level_dimensions[level])
    rgb = np.array(region.convert("RGB"))
    # Saturation separates stained tissue from bright, low-saturation background.
    saturation = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)[:, :, 1]
    smoothed = cv2.medianBlur(saturation, median_kernel)
    _, mask = cv2.threshold(smoothed, threshold, max_val, cv2.THRESH_BINARY)
    if close_size > 0:
        kernel = np.ones((close_size, close_size), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    return mask
118
+
119
+
120
def filter_contours_and_holes(
    binary_mask: np.ndarray,
    min_effective_area: float,
    min_hole_area: float,
    max_n_holes: int,
):
    """
    Find external contours in the mask and keep those whose effective area
    (contour area minus selected hole area) exceeds min_effective_area.

    For each external contour, child contours (holes) are sorted by area
    (largest first); at most max_n_holes holes larger than min_hole_area are
    kept and their area subtracted from the external contour's area.

    Returns:
        filtered_contours (list): Kept external contours (numpy arrays).
        holes_list (list): For each kept contour, its selected hole contours.
    """
    contours, hierarchy = cv2.findContours(
        binary_mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
    )
    if hierarchy is None:
        return [], []

    hierarchy = hierarchy[0]  # (N, 4): [next, prev, first_child, parent]
    kept_contours = []
    kept_holes = []
    for idx, info in enumerate(hierarchy):
        # A parent index of -1 marks an external (top-level) contour.
        if info[3] != -1:
            continue
        outer = contours[idx]

        children = [contours[i] for i, rec in enumerate(hierarchy) if rec[3] == idx]
        children.sort(key=cv2.contourArea, reverse=True)
        holes = [
            c for c in children[:max_n_holes] if cv2.contourArea(c) > min_hole_area
        ]

        effective = cv2.contourArea(outer) - sum(cv2.contourArea(c) for c in holes)
        if effective > min_effective_area:
            kept_contours.append(outer)
            kept_holes.append(holes)
    return kept_contours, kept_holes
168
+
169
+
170
def scale_contours(contours: list, scale: tuple) -> list:
    """
    Scale contour coordinates by the provided (sx, sy) factors.

    Args:
        contours: List of contours (each a numpy array of points).
        scale: Tuple (sx, sy) for scaling.

    Returns:
        List of scaled contours as int32 arrays.
    """
    factors = np.array(scale, dtype=np.float32)
    # Multiply in float32, then truncate back to int32 pixel coordinates.
    return [(cont * factors).astype(np.int32) for cont in contours]
185
+
186
+
187
def point_in_tissue(x: int, y: int, ext_contours: list, holes_list: list) -> bool:
    """
    Return True if (x, y) lies inside some external contour while not being
    inside any of that contour's holes; otherwise False.
    """
    pt = (x, y)
    for contour, holes in zip(ext_contours, holes_list):
        # Negative result means the point is outside this tissue outline.
        if cv2.pointPolygonTest(contour, pt, False) < 0:
            continue
        # Inside the outline; accept unless the point falls inside a hole.
        in_hole = any(cv2.pointPolygonTest(h, pt, False) >= 0 for h in holes)
        if not in_hole:
            return True
    return False
204
+
205
+
206
def tile(x: torch.Tensor, size: int):
    """Split a (B, C, H, W) tensor into non-overlapping (size x size) tiles.

    The spatial dims are zero-padded on the bottom/right up to a multiple of
    ``size``; the result has shape (B * nh * nw, C, size, size), with tiles
    ordered row-major within each image.
    """
    C, H, W = x.shape[-3:]

    # Pad so H and W divide evenly by the tile size.
    extra_h = (size - H % size) % size
    extra_w = (size - W % size) % size
    if extra_h or extra_w:
        x = torch.nn.functional.pad(x, (0, extra_w, 0, extra_h))

    rows, cols = x.size(2) // size, x.size(3) // size
    tiles = x.view(-1, C, rows, size, cols, size)
    tiles = tiles.permute(0, 2, 4, 1, 3, 5)
    return tiles.reshape(-1, C, size, size)