편명장/님/(myeongjang.pyeon)
commited on
Commit
·
287a683
1
Parent(s):
4083dcc
initial commit
Browse files- .gitignore +172 -0
- LICENSE +57 -0
- README.md +35 -5
- assets/exaonepath_v1.png +0 -0
- configs/inference.yaml +26 -0
- configs/metadata.json +50 -0
- models/.gitkeep +0 -0
- models/exaonepath_v1.0.0_msi.pt +3 -0
- models/macenko_param.pt +3 -0
- requirements.txt +13 -0
- scripts/__init__.py +0 -0
- scripts/aggregator.py +334 -0
- scripts/constants.py +1 -0
- scripts/dataset.py +47 -0
- scripts/exaonepath.py +92 -0
- scripts/feature_extractor.py +357 -0
- scripts/inference.py +11 -0
- scripts/preprocessor.py +218 -0
- scripts/wsi_utils.py +219 -0
.gitignore
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# UV
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
#uv.lock
|
102 |
+
|
103 |
+
# poetry
|
104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
106 |
+
# commonly ignored for libraries.
|
107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
108 |
+
#poetry.lock
|
109 |
+
|
110 |
+
# pdm
|
111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
112 |
+
#pdm.lock
|
113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
114 |
+
# in version control.
|
115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
116 |
+
.pdm.toml
|
117 |
+
.pdm-python
|
118 |
+
.pdm-build/
|
119 |
+
|
120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
121 |
+
__pypackages__/
|
122 |
+
|
123 |
+
# Celery stuff
|
124 |
+
celerybeat-schedule
|
125 |
+
celerybeat.pid
|
126 |
+
|
127 |
+
# SageMath parsed files
|
128 |
+
*.sage.py
|
129 |
+
|
130 |
+
# Environments
|
131 |
+
.env
|
132 |
+
.venv
|
133 |
+
env/
|
134 |
+
venv/
|
135 |
+
ENV/
|
136 |
+
env.bak/
|
137 |
+
venv.bak/
|
138 |
+
|
139 |
+
# Spyder project settings
|
140 |
+
.spyderproject
|
141 |
+
.spyproject
|
142 |
+
|
143 |
+
# Rope project settings
|
144 |
+
.ropeproject
|
145 |
+
|
146 |
+
# mkdocs documentation
|
147 |
+
/site
|
148 |
+
|
149 |
+
# mypy
|
150 |
+
.mypy_cache/
|
151 |
+
.dmypy.json
|
152 |
+
dmypy.json
|
153 |
+
|
154 |
+
# Pyre type checker
|
155 |
+
.pyre/
|
156 |
+
|
157 |
+
# pytype static type analyzer
|
158 |
+
.pytype/
|
159 |
+
|
160 |
+
# Cython debug symbols
|
161 |
+
cython_debug/
|
162 |
+
|
163 |
+
# PyCharm
|
164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
168 |
+
#.idea/
|
169 |
+
|
170 |
+
# PyPI configuration file
|
171 |
+
.pypirc
|
172 |
+
|
LICENSE
CHANGED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
EXAONEPath AI Model License Agreement 1.0 - NC
|
2 |
+
|
3 |
+
This License Agreement (“Agreement”) is entered into between you (“Licensee”) and LG Management Development Institute Co., Ltd. (“Licensor”), governing the use of the EXAONEPath AI Model (“Model”). By downloading, installing, copying, or using the Model, you agree to comply with and be bound by the terms of this Agreement. If you do not agree to all the terms, you must not download, install, copy, or use the Model. This Agreement constitutes a binding legal agreement between the Licensee and Licensor.
|
4 |
+
|
5 |
+
1. Definitions
|
6 |
+
1.1 Model: The artificial intelligence model provided by Licensor, which includes any software, algorithms, machine learning models, or related components supplied by Licensor. This definition extends to encompass all updates, enhancements, improvements, bug fixes, patches, or other modifications that may be provided by Licensor from time to time, whether automatically or manually implemented.
|
7 |
+
1.2 Derivatives: Any modifications, alterations, enhancements, improvements, adaptations, or derivative works of the Model created by Licensee or any third party. This includes changes made to the Model's architecture, parameters, data processing methods, or any other aspect of the Model that results in a modification of its functionality or output.
|
8 |
+
1.3 Output: Any data, results, content, predictions, analyses, insights, or other materials generated by the Model or Derivatives, regardless of whether they are in their original form or have been further processed or modified by the Licensee. This includes, but is not limited to, textual or numerical produced directly or indirectly through the use of the Model.
|
9 |
+
1.4 Licensor: LG Management Development Institute Co., Ltd., the owner, developer, and provider of the EXAONEPath AI Model. The Licensor holds all rights, title, and interest in the Model and is responsible for granting licenses to use the Model under the terms specified in this Agreement.
|
10 |
+
1.5 Licensee: The individual, organization, corporation, academic institution, government agency, or other entity using or intending to use the Model under the terms and conditions of this Agreement. The Licensee is responsible for ensuring compliance with the Agreement by all authorized users who access or utilize the Model on behalf of the Licensee.
|
11 |
+
|
12 |
+
2. License Grant
|
13 |
+
2.1 Grant of License: Subject to the terms and conditions outlined in this Agreement, the Licensor hereby grants the Licensee a limited, non-exclusive, non-transferable, worldwide, and revocable license to:
|
14 |
+
a. Access, download, install, and use the Model solely for research purposes. This includes evaluation, testing, academic research and experimentation.
|
15 |
+
b. Publicly disclose research results and findings derived from the use of the Model or Derivatives, including publishing papers or presentations.
|
16 |
+
c. Modify the Model and create Derivatives based on the Model, provided that such modifications and Derivatives are used exclusively for research purposes. The Licensee may conduct experiments, perform analyses, and apply custom modifications to the Model to explore its capabilities and performance under various scenarios. If the Model is modified, the modified Model must include "EXAONEPath" at the beginning of its name.
|
17 |
+
d. Distribute the Model and Derivatives in each case with a copy of this Agreement.
|
18 |
+
2.2 Scope of License: The license granted herein does not authorize the Licensee to use the Model for any purpose not explicitly permitted under this Agreement. Any use beyond the scope of this license, including any commercial application or external distribution, is strictly prohibited unless explicitly agreed upon in writing by the Licensor.
|
19 |
+
|
20 |
+
3. Restrictions
|
21 |
+
3.1 Commercial Use: The Licensee is expressly prohibited from using the Model, Derivatives, or Output for any commercial purposes, including but not limited to, developing or deploying products, services, or applications that generate revenue, whether directly or indirectly. Any commercial exploitation of the Model or its derivatives requires a separate commercial license agreement with the Licensor. Furthermore, the Licensee shall not use the Model, Derivatives or Output to develop or improve other models, except for research purposes, which is explicitly permitted.
|
22 |
+
3.2 Reverse Engineering: The Licensee shall not decompile, disassemble, reverse engineer, or attempt to derive the source code, underlying ideas, algorithms, or structure of the Model, except to the extent that such activities are expressly permitted by applicable law. Any attempt to bypass or circumvent technological protection measures applied to the Model is strictly prohibited.
|
23 |
+
3.3 Unlawful Use: The Licensee shall not use the Model and Derivatives for any illegal, fraudulent, or unauthorized activities, nor for any purpose that violates applicable laws or regulations. This includes but is not limited to the creation, distribution, or dissemination of malicious, deceptive, or unlawful content.
|
24 |
+
3.4 Ethical Use: The Licensee shall ensure that the Model or Derivatives is used in an ethical and responsible manner, adhering to the following guidelines:
|
25 |
+
a. The Model and Derivatives shall not be used to generate, propagate, or amplify false, misleading, or harmful information, including fake news, misinformation, or disinformation.
|
26 |
+
b. The Model and Derivatives shall not be employed to create, distribute, or promote content that is discriminatory, harassing, defamatory, abusive, or otherwise offensive to individuals or groups based on race, gender, sexual orientation, religion, nationality, or other protected characteristics.
|
27 |
+
c. The Model and Derivatives shall not infringe on the rights of others, including intellectual property rights, privacy rights, or any other rights recognized by law. The Licensee shall obtain all necessary permissions and consents before using the Model and Derivatives in a manner that may impact the rights of third parties.
|
28 |
+
d. The Model and Derivatives shall not be used in a way that causes harm, whether physical, mental, emotional, or financial, to individuals, organizations, or communities. The Licensee shall take all reasonable measures to prevent misuse or abuse of the Model and Derivatives that could result in harm or injury.
|
29 |
+
|
30 |
+
4. Ownership
|
31 |
+
4.1 Intellectual Property: All rights, title, and interest in and to the Model, including any modifications, Derivatives, and associated documentation, are and shall remain the exclusive property of the Licensor. The Licensee acknowledges that this Agreement does not transfer any ownership rights to the Licensee. All trademarks, service marks, and logos associated with the Model are the property of the Licensor.
|
32 |
+
4.2 Output: All output generated by the Model from Licensee Data ("Output") shall be the sole property of the Licensee. Licensor hereby waives any claim of ownership or intellectual property rights to the Output. Licensee is solely responsible for the legality, accuracy, quality, integrity, and use of the Output.
|
33 |
+
4.3 Attribution: In any publication or presentation of results obtained using the Model, the Licensee shall provide appropriate attribution to the Licensor, citing the Model's name and version, along with any relevant documentation or references specified by the Licensor.
|
34 |
+
|
35 |
+
5. No Warranty
|
36 |
+
5.1 “As-Is” Basis: The Model, Derivatives, and Output are provided on an “as-is” and “as-available” basis, without any warranties or representations of any kind, whether express, implied, or statutory. The Licensor disclaims all warranties, including but not limited to, implied warranties of merchantability, fitness for a particular purpose, accuracy, reliability, non-infringement, or any warranty arising from the course of dealing or usage of trade.
|
37 |
+
5.2 Performance and Reliability: The Licensor does not warrant or guarantee that the Model, Derivatives or Output will meet the Licensee’s requirements, that the operation of the Model, Derivatives or Output will be uninterrupted or error-free, or that defects in the Model will be corrected. The Licensee acknowledges that the use of the Model, Derivatives or Output is at its own risk and that the Model, Derivatives or Output may contain bugs, errors, or other limitations.
|
38 |
+
5.3 No Endorsement: The Licensor does not endorse, approve, or certify any results, conclusions, or recommendations derived from the use of the Model. The Licensee is solely responsible for evaluating the accuracy, reliability, and suitability of the Model for its intended purposes.
|
39 |
+
|
40 |
+
6. Limitation of Liability
|
41 |
+
6.1 No Liability for Damages: To the fullest extent permitted by applicable law, in no event shall the Licensor be liable for any special, incidental, indirect, consequential, exemplary, or punitive damages, including but not limited to, damages for loss of business profits, business interruption, loss of business information, loss of data, or any other pecuniary or non-pecuniary loss arising out of or in connection with the use or inability to use the Model, Derivatives or any Output, even if the Licensor has been advised of the possibility of such damages.
|
42 |
+
6.2 Indemnification: The Licensee agrees to indemnify, defend, and hold harmless the Licensor, its affiliates, officers, directors, employees, and agents from and against any claims, liabilities, damages, losses, costs, or expenses (including reasonable attorneys' fees) arising out of or related to the Licensee's use of the Model, any Derivatives, or any Output, including any violation of this Agreement or applicable laws. This includes, but is not limited to, ensuring compliance with copyright laws, privacy regulations, defamation laws, and any other applicable legal or regulatory requirements.
|
43 |
+
|
44 |
+
7. Termination
|
45 |
+
7.1 Termination by Licensor: The Licensor reserves the right to terminate this Agreement and revoke the Licensee’s rights to use the Model at any time, with or without cause, and without prior notice if the Licensee breaches any of the terms or conditions of this Agreement. Termination shall be effective immediately upon notice.
|
46 |
+
7.2 Effect of Termination: Upon termination of this Agreement, the Licensee must immediately cease all use of the Model, Derivatives, and Output and destroy all copies of the Model, Derivatives, and Output in its possession or control, including any backup or archival copies. The Licensee shall certify in writing to the Licensor that such destruction has been completed.
|
47 |
+
7.3 Survival: The provisions of this Agreement that by their nature should survive termination, including but not limited to, Sections 4 (Ownership), 5 (No Warranty), 6 (Limitation of Liability), and this Section 7 (Termination), shall continue to apply after termination.
|
48 |
+
|
49 |
+
8. Governing Law
|
50 |
+
8.1 Governing Law: This Agreement shall be governed by and construed in accordance with the laws of the Republic of Korea, without regard to its conflict of laws principles.
|
51 |
+
8.2 Arbitration: Any disputes, controversies, or claims arising out of or relating to this Agreement, including its existence, validity, interpretation, performance, breach, or termination, shall be referred to and finally resolved by arbitration administered by the Korean Commercial Arbitration Board (KCAB) in accordance with the International Arbitration Rules of the Korean Commercial Arbitration Board in force at the time of the commencement of the arbitration. The seat of arbitration shall be Seoul, Republic of Korea. The tribunal shall consist of one arbitrator. The language of the arbitration shall be English.
|
52 |
+
|
53 |
+
9. Alterations
|
54 |
+
9.1 Modifications: The Licensor reserves the right to modify or amend this Agreement at any time, in its sole discretion. Any modifications will be effective upon posting the updated Agreement on the Licensor’s website or through other means of communication. The Licensee is responsible for reviewing the Agreement periodically for changes. Continued use of the Model after any modifications have been made constitutes acceptance of the revised Agreement.
|
55 |
+
9.2 Entire Agreement: This Agreement constitutes the entire agreement between the Licensee and Licensor concerning the subject matter hereof and supersedes all prior or contemporaneous oral or written agreements, representations, or understandings. Any terms or conditions of any purchase order or other document submitted by the Licensee in connection with the Model that are in addition to, different from, or inconsistent with the terms and conditions of this Agreement are not binding on the Licensor and are void.
|
56 |
+
|
57 |
+
By downloading, installing, or using the EXAONEPath AI Model, the Licensee acknowledges that it has read, understood, and agrees to be bound by the terms and conditions of this Agreement.
|
README.md
CHANGED
@@ -1,5 +1,35 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Description
|
2 |
+
MSI classification of CRC tumors using EXAONEPath 1.0.0 - a patch-level foundation model for pathology
|
3 |
+
|
4 |
+
# Model Overview
|
5 |
+
This model serves as a reference for predicting MSI status using CRC (colorectal cancer) tumor images as input. When the model receives an H&E-stained whole slide image as input, it removes artifacts observed in the image and extracts only tissue-related objects. These objects are then reconstructed into a set of tiles with a size of 256 by 256 pixels at an mpp (micron per pixel) of 0.5.
|
6 |
+
|
7 |
+
The tiles pass through the EXAONEPath v1.0 patch-level foundation model (https://huggingface.co/LGAI-EXAONE/EXAONEPath), which converts them into a set of features. These features are then integrated into a slide-level feature representation through an aggregator. Finally, a linear classifier predicts the MSI status (MSS or MSI-H/L).
|
8 |
+
|
9 |
+
The model achieves an average performance of AUROC 0.93 on TCGA-COAD + TCGA-READ data and 0.84 on in-house data.
|
10 |
+
|
11 |
+
This open-source release aims to demonstrate that the combination of EXAONEPath and the aggregator can effectively perform pathological tasks. It is hoped that this source code will serve as an important reference not only for CRC MSI prediction but also as an image-based solution for various disease-related problems, including molecular subtyping, tumor subtyping, and mutation prediction.
|
12 |
+
|
13 |
+
This open-source code is designed to be compatible with the MONAI framework (https://monai.io/), and users must check the accompanying license before use.
|
14 |
+
|
15 |
+

|
16 |
+
|
17 |
+
# Input and Output Formats
|
18 |
+
The input is a `.svs` file containing whole slide image.
|
19 |
+
The output is an array with probabilities for each of the two classes (MSS / MSI(H/L)).
|
20 |
+
|
21 |
+
# Execute Inference
|
22 |
+
The inference can be executed as follows
|
23 |
+
1. Copy your `.svs` files into `samples` directory
|
24 |
+
2. Run inference
|
25 |
+
```
|
26 |
+
python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file configs/inference.json configs/logging.conf
|
27 |
+
```
|
28 |
+
|
29 |
+
## Contact
|
30 |
+
LG AI Research Technical Support: <a href="mailto:[email protected]">[email protected]</a>
|
31 |
+
|
32 |
+
# License
|
33 |
+
Copyright (c) LG AI Research
|
34 |
+
|
35 |
+
The model is licensed under [EXAONEPath AI Model License Agreement 1.0 - NC](LICENSE).
|
assets/exaonepath_v1.png
ADDED
![]() |
configs/inference.yaml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
imports:
|
2 |
+
- $import os
|
3 |
+
- $import glob
|
4 |
+
- $import torch
|
5 |
+
- $import scripts
|
6 |
+
- $import scripts.inference
|
7 |
+
|
8 |
+
device: '$torch.device("cuda" if torch.cuda.is_available() else "cpu")'
|
9 |
+
|
10 |
+
model_config:
|
11 |
+
_target_: scripts.exaonepath.EXAONEPathV1Downstream
|
12 |
+
step_size: 256
|
13 |
+
patch_size: 256
|
14 |
+
macenko: true
|
15 |
+
device: '$@device'
|
16 |
+
|
17 |
+
model: '$@model_config'
|
18 |
+
|
19 |
+
input_dir: 'samples'
|
20 |
+
input_files: '$sorted(glob.glob(@input_dir+''/*.svs''))'
|
21 |
+
|
22 |
+
root_dir: '$os.path.dirname(os.path.dirname(scripts.__file__))'
|
23 |
+
|
24 |
+
inference:
|
25 |
+
- [email protected]_state_dict(torch.load(os.path.join(@root_dir, "models/exaonepath_v1.0.0_msi.pt")))
|
26 |
+
- $scripts.inference.infer(@model, @input_files)
|
configs/metadata.json
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"version": "0.1.0",
|
3 |
+
"changelog": {
|
4 |
+
"0.1.0": "Initial release"
|
5 |
+
},
|
6 |
+
"monai_version": "1.5",
|
7 |
+
"pytorch_version": "2.6.0",
|
8 |
+
"numpy_version": "1.26.4",
|
9 |
+
"required_packages_version": {
|
10 |
+
"torchvision": "0.21.0",
|
11 |
+
"openslide-bin": "4.0.0.6",
|
12 |
+
"openslide-python": "1.4.1",
|
13 |
+
"pandas": "2.2.3",
|
14 |
+
"torchstain": "1.4.1",
|
15 |
+
"opencv-python-headless": "4.11.0.86",
|
16 |
+
"einops": "0.8.1"
|
17 |
+
},
|
18 |
+
"name": "MSI predictor from WSI using EXAONEPath v1.0",
|
19 |
+
"task": "MSI Classification of Colorectal Cancer (CRC) Tumors",
|
20 |
+
"description": "MSI classification of CRC tumors using EXAONEPath 1.0.0 - a patch-level foundation model for pathology",
|
21 |
+
"authors": "LG AI Research",
|
22 |
+
"copyright": "Copyright (c) LG AI Research",
|
23 |
+
"data_source": "Whole Slide Images",
|
24 |
+
"data_type": "float32",
|
25 |
+
"intended_use": "Research only",
|
26 |
+
"network_data_format": {
|
27 |
+
"inputs": {
|
28 |
+
"image": {
|
29 |
+
"type": "image",
|
30 |
+
"format": "svs",
|
31 |
+
"modality": "pathology"
|
32 |
+
}
|
33 |
+
},
|
34 |
+
"outputs": {
|
35 |
+
"pred": {
|
36 |
+
"type": "probabilities",
|
37 |
+
"format": "classes",
|
38 |
+
"num_channels": 2,
|
39 |
+
"spatial_shape": [2],
|
40 |
+
"dtype": "float32",
|
41 |
+
"value_range": [0, 1],
|
42 |
+
"is_patch_data": false,
|
43 |
+
"channel_def": {
|
44 |
+
"0": "MSS",
|
45 |
+
"1": "MSI"
|
46 |
+
}
|
47 |
+
}
|
48 |
+
}
|
49 |
+
}
|
50 |
+
}
|
models/.gitkeep
ADDED
File without changes
|
models/exaonepath_v1.0.0_msi.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1c7c3199335922de3c156ede12e2d655b23ec988529be18611981832d6ab8095
|
3 |
+
size 625193090
|
models/macenko_param.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:74a6e178f1d3ea897152e544857eb880045a8ffb2be1e05b1639386083e68435
|
3 |
+
size 999
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu124
|
2 |
+
|
3 |
+
torch==2.6.0+cu124
|
4 |
+
torchvision==0.21.0+cu124
|
5 |
+
numpy==1.26.4
|
6 |
+
opencv-python-headless==4.11.0.86
|
7 |
+
openslide-bin==4.0.0.6
|
8 |
+
openslide-python==1.4.1
|
9 |
+
monai-weekly[ignite,pyyaml]==1.5.dev2506
|
10 |
+
pandas==2.2.3
|
11 |
+
torchstain==1.4.1
|
12 |
+
fire==0.7.0
|
13 |
+
einops==0.8.1
|
scripts/__init__.py
ADDED
File without changes
|
scripts/aggregator.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import wraps
|
2 |
+
from math import log, pi
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from einops import rearrange, repeat
|
7 |
+
from einops.layers.torch import Rearrange, Reduce
|
8 |
+
from torch import einsum, nn
|
9 |
+
|
10 |
+
|
11 |
+
def exists(val):
|
12 |
+
return val is not None
|
13 |
+
|
14 |
+
|
15 |
+
def default(val, d):
|
16 |
+
return val if exists(val) else d
|
17 |
+
|
18 |
+
|
19 |
+
def cache_fn(f):
|
20 |
+
cache = dict()
|
21 |
+
|
22 |
+
@wraps(f)
|
23 |
+
def cached_fn(*args, _cache=True, key=None, **kwargs):
|
24 |
+
if not _cache:
|
25 |
+
return f(*args, **kwargs)
|
26 |
+
nonlocal cache
|
27 |
+
if key in cache:
|
28 |
+
return cache[key]
|
29 |
+
result = f(*args, **kwargs)
|
30 |
+
cache[key] = result
|
31 |
+
return result
|
32 |
+
|
33 |
+
return cached_fn
|
34 |
+
|
35 |
+
|
36 |
+
def fourier_encode(x, max_freq, num_bands=4):
|
37 |
+
x = x.unsqueeze(-1)
|
38 |
+
device, dtype, orig_x = x.device, x.dtype, x
|
39 |
+
|
40 |
+
scales = torch.linspace(1.0, max_freq / 2, num_bands, device=device, dtype=dtype)
|
41 |
+
scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
|
42 |
+
|
43 |
+
x = x * scales * pi
|
44 |
+
x = torch.cat([x.sin(), x.cos()], dim=-1)
|
45 |
+
x = torch.cat((x, orig_x), dim=-1)
|
46 |
+
return x
|
47 |
+
|
48 |
+
|
49 |
+
class PreNorm(nn.Module):
|
50 |
+
def __init__(self, dim, fn, context_dim=None):
|
51 |
+
super().__init__()
|
52 |
+
self.fn = fn
|
53 |
+
self.norm = nn.LayerNorm(dim)
|
54 |
+
self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None
|
55 |
+
|
56 |
+
def forward(self, x, **kwargs):
|
57 |
+
x = self.norm(x)
|
58 |
+
|
59 |
+
if exists(self.norm_context):
|
60 |
+
context = kwargs["context"]
|
61 |
+
normed_context = self.norm_context(context)
|
62 |
+
kwargs.update(context=normed_context)
|
63 |
+
|
64 |
+
return self.fn(x, **kwargs)
|
65 |
+
|
66 |
+
|
67 |
+
class GEGLU(nn.Module):
|
68 |
+
def forward(self, x):
|
69 |
+
x, gates = x.chunk(2, dim=-1)
|
70 |
+
return x * F.gelu(gates)
|
71 |
+
|
72 |
+
|
73 |
+
class FeedForward(nn.Module):
|
74 |
+
def __init__(self, dim, mult=4, dropout=0.0):
|
75 |
+
super().__init__()
|
76 |
+
self.net = nn.Sequential(
|
77 |
+
nn.Linear(dim, dim * mult * 2),
|
78 |
+
GEGLU(),
|
79 |
+
nn.Linear(dim * mult, dim),
|
80 |
+
nn.Dropout(dropout),
|
81 |
+
)
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
return self.net(x)
|
85 |
+
|
86 |
+
|
87 |
+
class Attention(nn.Module):
|
88 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
89 |
+
super().__init__()
|
90 |
+
inner_dim = dim_head * heads
|
91 |
+
context_dim = default(context_dim, query_dim)
|
92 |
+
|
93 |
+
self.scale = dim_head**-0.5
|
94 |
+
self.heads = heads
|
95 |
+
|
96 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
97 |
+
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
|
98 |
+
|
99 |
+
self.dropout = nn.Dropout(dropout)
|
100 |
+
self.to_out = nn.Linear(inner_dim, query_dim)
|
101 |
+
|
102 |
+
def forward(self, x, context=None, mask=None):
|
103 |
+
h = self.heads
|
104 |
+
|
105 |
+
q = self.to_q(x)
|
106 |
+
context = default(context, x)
|
107 |
+
k, v = self.to_kv(context).chunk(2, dim=-1)
|
108 |
+
|
109 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
|
110 |
+
|
111 |
+
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
|
112 |
+
|
113 |
+
if exists(mask):
|
114 |
+
mask = rearrange(mask, "b ... -> b (...)")
|
115 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
116 |
+
mask = repeat(mask, "b j -> (b h) () j", h=h)
|
117 |
+
sim.masked_fill_(~mask, max_neg_value)
|
118 |
+
|
119 |
+
# attention, what we cannot get enough of
|
120 |
+
attn = sim.softmax(dim=-1)
|
121 |
+
attn = self.dropout(attn)
|
122 |
+
|
123 |
+
out = einsum("b i j, b j d -> b i d", attn, v)
|
124 |
+
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
|
125 |
+
return self.to_out(out)
|
126 |
+
|
127 |
+
|
128 |
+
class Perceiver(nn.Module):
|
129 |
+
def __init__(
|
130 |
+
self,
|
131 |
+
*,
|
132 |
+
num_freq_bands,
|
133 |
+
depth,
|
134 |
+
max_freq,
|
135 |
+
input_channels=3,
|
136 |
+
input_axis=2,
|
137 |
+
num_latents=512,
|
138 |
+
latent_dim=512,
|
139 |
+
cross_heads=1,
|
140 |
+
latent_heads=8,
|
141 |
+
cross_dim_head=64,
|
142 |
+
latent_dim_head=64,
|
143 |
+
num_classes=1000,
|
144 |
+
attn_dropout=0.0,
|
145 |
+
ff_dropout=0.0,
|
146 |
+
weight_tie_layers=False,
|
147 |
+
fourier_encode_data=True,
|
148 |
+
self_per_cross_attn=1,
|
149 |
+
final_classifier_head=True,
|
150 |
+
pool="mean",
|
151 |
+
latent_init=None,
|
152 |
+
):
|
153 |
+
"""The shape of the final attention mechanism will be:
|
154 |
+
depth * (cross attention -> self_per_cross_attn * self attention)
|
155 |
+
|
156 |
+
Args:
|
157 |
+
num_freq_bands: Number of freq bands, with original value (2 * K + 1)
|
158 |
+
depth: Depth of net.
|
159 |
+
max_freq: Maximum frequency, hyperparameter depending on how
|
160 |
+
fine the data is.
|
161 |
+
freq_base: Base for the frequency
|
162 |
+
input_channels: Number of channels for each token of the input.
|
163 |
+
input_axis: Number of axes for input data (2 for images, 3 for video)
|
164 |
+
num_latents: Number of latents, or induced set points, or centroids.
|
165 |
+
Different papers giving it different names.
|
166 |
+
latent_dim: Latent dimension.
|
167 |
+
cross_heads: Number of heads for cross attention. Paper said 1.
|
168 |
+
latent_heads: Number of heads for latent self attention, 8.
|
169 |
+
cross_dim_head: Number of dimensions per cross attention head.
|
170 |
+
latent_dim_head: Number of dimensions per latent self attention head.
|
171 |
+
num_classes: Output number of classes.
|
172 |
+
attn_dropout: Attention dropout
|
173 |
+
ff_dropout: Feedforward dropout
|
174 |
+
weight_tie_layers: Whether to weight tie layers (optional).
|
175 |
+
fourier_encode_data: Whether to auto-fourier encode the data, using
|
176 |
+
the input_axis given. defaults to True, but can be turned off
|
177 |
+
if you are fourier encoding the data yourself.
|
178 |
+
self_per_cross_attn: Number of self attention blocks per cross attn.
|
179 |
+
final_classifier_head: mean pool and project embeddings to number of classes (num_classes) at the end
|
180 |
+
"""
|
181 |
+
super().__init__()
|
182 |
+
self.input_axis = input_axis
|
183 |
+
self.max_freq = max_freq
|
184 |
+
self.num_freq_bands = num_freq_bands
|
185 |
+
self.self_per_cross_attn = self_per_cross_attn
|
186 |
+
|
187 |
+
self.fourier_encode_data = fourier_encode_data
|
188 |
+
fourier_channels = (
|
189 |
+
(input_axis * ((num_freq_bands * 2) + 1)) * 2 if fourier_encode_data else 0
|
190 |
+
)
|
191 |
+
input_dim = fourier_channels + input_channels
|
192 |
+
|
193 |
+
self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
|
194 |
+
if latent_init is not None:
|
195 |
+
latent_init_feat = torch.load(latent_init)
|
196 |
+
if type(latent_init_feat) != torch.Tensor:
|
197 |
+
latent_init_feat = torch.Tensor(latent_init_feat)
|
198 |
+
if len(latent_init_feat.shape) == 3:
|
199 |
+
latent_init_feat = latent_init_feat[0]
|
200 |
+
with torch.no_grad():
|
201 |
+
self.latents.copy_(latent_init_feat)
|
202 |
+
print(f"load latent feature: , {latent_init}")
|
203 |
+
|
204 |
+
get_cross_attn = lambda: PreNorm(
|
205 |
+
latent_dim,
|
206 |
+
Attention(
|
207 |
+
latent_dim,
|
208 |
+
input_dim,
|
209 |
+
heads=cross_heads,
|
210 |
+
dim_head=cross_dim_head,
|
211 |
+
dropout=attn_dropout,
|
212 |
+
),
|
213 |
+
context_dim=input_dim,
|
214 |
+
)
|
215 |
+
get_cross_ff = lambda: PreNorm(
|
216 |
+
latent_dim, FeedForward(latent_dim, dropout=ff_dropout)
|
217 |
+
)
|
218 |
+
get_latent_attn = lambda: PreNorm(
|
219 |
+
latent_dim,
|
220 |
+
Attention(
|
221 |
+
latent_dim,
|
222 |
+
heads=latent_heads,
|
223 |
+
dim_head=latent_dim_head,
|
224 |
+
dropout=attn_dropout,
|
225 |
+
),
|
226 |
+
)
|
227 |
+
get_latent_ff = lambda: PreNorm(
|
228 |
+
latent_dim, FeedForward(latent_dim, dropout=ff_dropout)
|
229 |
+
)
|
230 |
+
|
231 |
+
get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(
|
232 |
+
cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff)
|
233 |
+
)
|
234 |
+
|
235 |
+
self.layers = nn.ModuleList([])
|
236 |
+
for i in range(depth):
|
237 |
+
should_cache = i > 0 and weight_tie_layers
|
238 |
+
cache_args = {"_cache": should_cache}
|
239 |
+
|
240 |
+
self_attns = nn.ModuleList([])
|
241 |
+
|
242 |
+
for block_ind in range(self_per_cross_attn):
|
243 |
+
self_attns.append(
|
244 |
+
nn.ModuleList(
|
245 |
+
[
|
246 |
+
get_latent_attn(**cache_args, key=block_ind),
|
247 |
+
get_latent_ff(**cache_args, key=block_ind),
|
248 |
+
]
|
249 |
+
)
|
250 |
+
)
|
251 |
+
if self_per_cross_attn == 0:
|
252 |
+
self_attns.append(get_latent_ff(**cache_args, key=block_ind))
|
253 |
+
|
254 |
+
self.layers.append(
|
255 |
+
nn.ModuleList(
|
256 |
+
[
|
257 |
+
get_cross_attn(**cache_args),
|
258 |
+
get_cross_ff(**cache_args),
|
259 |
+
self_attns,
|
260 |
+
]
|
261 |
+
)
|
262 |
+
)
|
263 |
+
|
264 |
+
if final_classifier_head:
|
265 |
+
if pool == "cat":
|
266 |
+
self.to_logits = nn.Sequential(
|
267 |
+
Rearrange("b n d -> b (n d)"),
|
268 |
+
nn.LayerNorm(num_latents * latent_dim),
|
269 |
+
nn.Linear(num_latents * latent_dim, num_classes),
|
270 |
+
)
|
271 |
+
elif pool == "mlp":
|
272 |
+
self.to_logits = nn.Sequential(
|
273 |
+
Reduce("b n d -> b d", "mean"),
|
274 |
+
nn.LayerNorm(latent_dim),
|
275 |
+
nn.Linear(latent_dim, latent_dim),
|
276 |
+
nn.ReLU(),
|
277 |
+
nn.LayerNorm(latent_dim),
|
278 |
+
nn.Linear(latent_dim, num_classes),
|
279 |
+
)
|
280 |
+
else:
|
281 |
+
self.to_logits = nn.Sequential(
|
282 |
+
Reduce("b n d -> b d", pool),
|
283 |
+
nn.LayerNorm(latent_dim),
|
284 |
+
nn.Linear(latent_dim, num_classes),
|
285 |
+
)
|
286 |
+
|
287 |
+
def forward(self, h, label=None, mask=None, pretrain=False, coords=None):
|
288 |
+
b, *axis, _, device, dtype = *h.shape, h.device, h.dtype
|
289 |
+
assert (
|
290 |
+
len(axis) == self.input_axis
|
291 |
+
), "input data must have the right number of axis"
|
292 |
+
|
293 |
+
if self.fourier_encode_data:
|
294 |
+
# calculate fourier encoded positions in the range of [-1, 1], for all axis
|
295 |
+
|
296 |
+
# axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps=size, device=device, dtype=dtype), axis))
|
297 |
+
# pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
|
298 |
+
# enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
|
299 |
+
# enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
|
300 |
+
# enc_pos = repeat(enc_pos, '... -> b ...', b = b)
|
301 |
+
|
302 |
+
enc_pos = fourier_encode(coords, self.max_freq, self.num_freq_bands)
|
303 |
+
enc_pos = rearrange(enc_pos, "... n d -> ... (n d)")
|
304 |
+
|
305 |
+
h = torch.cat((h, enc_pos), dim=-1)
|
306 |
+
|
307 |
+
# concat to channels of data and flatten axis
|
308 |
+
|
309 |
+
h = rearrange(h, "b ... d -> b (...) d")
|
310 |
+
|
311 |
+
x = repeat(self.latents, "n d -> b n d", b=b)
|
312 |
+
|
313 |
+
# layers
|
314 |
+
|
315 |
+
for cross_attn, cross_ff, self_attns in self.layers:
|
316 |
+
x = cross_attn(x, context=h, mask=mask) + x
|
317 |
+
x = cross_ff(x) + x
|
318 |
+
|
319 |
+
if self.self_per_cross_attn > 0:
|
320 |
+
for self_attn, self_ff in self_attns:
|
321 |
+
x = self_attn(x) + x
|
322 |
+
x = self_ff(x) + x
|
323 |
+
else:
|
324 |
+
x = self_attns[0](x) + x
|
325 |
+
# allow for fetching embeddings
|
326 |
+
|
327 |
+
if pretrain:
|
328 |
+
return x.mean(dim=1)
|
329 |
+
|
330 |
+
# to logits
|
331 |
+
logits = self.to_logits(x)
|
332 |
+
Y_hat = torch.topk(logits, 1, dim=1)[1]
|
333 |
+
Y_prob = F.softmax(logits, dim=1)
|
334 |
+
return logits, Y_prob, Y_hat
|
scripts/constants.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
CLASS_NAMES = ["MSS", "MSI"]
|
scripts/dataset.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from openslide import OpenSlide
|
4 |
+
from scripts.preprocessor import MacenkoNormalizer, preprocessor
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
|
7 |
+
|
8 |
+
class WSIPatchDataset(Dataset):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
coords,
|
12 |
+
wsi_path,
|
13 |
+
pretrained=False,
|
14 |
+
patch_size=256,
|
15 |
+
patch_level=0,
|
16 |
+
macenko=True,
|
17 |
+
):
|
18 |
+
self.pretrained = pretrained
|
19 |
+
self.wsi = OpenSlide(wsi_path)
|
20 |
+
self.patch_size = patch_size
|
21 |
+
self.patch_level = patch_level
|
22 |
+
|
23 |
+
if macenko:
|
24 |
+
normalizer = MacenkoNormalizer(
|
25 |
+
target_path=os.path.join(
|
26 |
+
os.path.dirname(os.path.dirname(os.path.join(__file__))),
|
27 |
+
"models",
|
28 |
+
"macenko_param.pt",
|
29 |
+
)
|
30 |
+
)
|
31 |
+
else:
|
32 |
+
normalizer = None
|
33 |
+
|
34 |
+
self.roi_transforms = preprocessor(pretrained=pretrained, normalizer=normalizer)
|
35 |
+
self.coords = coords
|
36 |
+
self.length = len(self.coords)
|
37 |
+
|
38 |
+
def __len__(self):
|
39 |
+
return self.length
|
40 |
+
|
41 |
+
def __getitem__(self, idx):
|
42 |
+
coord = self.coords[idx]
|
43 |
+
img = self.wsi.read_region(
|
44 |
+
coord, self.patch_level, (self.patch_size, self.patch_size)
|
45 |
+
).convert("RGB")
|
46 |
+
img = self.roi_transforms(img)
|
47 |
+
return img
|
scripts/exaonepath.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from scripts.aggregator import Perceiver
|
5 |
+
from scripts.dataset import WSIPatchDataset
|
6 |
+
from scripts.feature_extractor import vit_base
|
7 |
+
from scripts.wsi_utils import extract_tissue_patch_coords
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
|
10 |
+
|
11 |
+
class EXAONEPathV1Downstream(nn.Module):
|
12 |
+
def __init__(
|
13 |
+
self, device: torch.device, step_size=256, patch_size=256, macenko=True
|
14 |
+
):
|
15 |
+
super(EXAONEPathV1Downstream, self).__init__()
|
16 |
+
self.step_size = step_size
|
17 |
+
self.patch_size = patch_size
|
18 |
+
self.macenko = macenko
|
19 |
+
self.device = device
|
20 |
+
|
21 |
+
self.feature_extractor = vit_base()
|
22 |
+
self.feature_extractor = self.feature_extractor
|
23 |
+
self.feature_extractor = self.feature_extractor.to(self.device)
|
24 |
+
self.feature_extractor.eval()
|
25 |
+
|
26 |
+
self.agg_model = Perceiver(
|
27 |
+
input_channels=768,
|
28 |
+
input_axis=1,
|
29 |
+
num_freq_bands=6,
|
30 |
+
max_freq=10.0,
|
31 |
+
depth=6,
|
32 |
+
num_latents=256,
|
33 |
+
latent_dim=512,
|
34 |
+
cross_heads=1,
|
35 |
+
latent_heads=8,
|
36 |
+
cross_dim_head=64,
|
37 |
+
latent_dim_head=64,
|
38 |
+
num_classes=2,
|
39 |
+
fourier_encode_data=False,
|
40 |
+
self_per_cross_attn=2,
|
41 |
+
pool="mean",
|
42 |
+
)
|
43 |
+
self.agg_model.to(self.device)
|
44 |
+
self.agg_model.eval()
|
45 |
+
|
46 |
+
@torch.no_grad()
|
47 |
+
def forward(self, svs_path: str, feature_extractor_batch_size: int = 8):
|
48 |
+
# Extract patches
|
49 |
+
coords = extract_tissue_patch_coords(
|
50 |
+
svs_path, patch_size=self.patch_size, step_size=self.step_size
|
51 |
+
)
|
52 |
+
|
53 |
+
# Extract patch-level features
|
54 |
+
self.feature_extractor.eval()
|
55 |
+
patch_dataset = WSIPatchDataset(
|
56 |
+
coords=coords,
|
57 |
+
wsi_path=svs_path,
|
58 |
+
pretrained=True,
|
59 |
+
macenko=self.macenko,
|
60 |
+
patch_size=self.patch_size,
|
61 |
+
)
|
62 |
+
patch_loader = DataLoader(
|
63 |
+
dataset=patch_dataset,
|
64 |
+
batch_size=feature_extractor_batch_size,
|
65 |
+
num_workers=(
|
66 |
+
feature_extractor_batch_size * 2 if self.device.type == "cuda" else 0
|
67 |
+
),
|
68 |
+
pin_memory=self.device.type == "cuda",
|
69 |
+
)
|
70 |
+
features_list = []
|
71 |
+
for count, patches in enumerate(patch_loader):
|
72 |
+
print(
|
73 |
+
f"batch {count+1}/{len(patch_loader)}, {count * feature_extractor_batch_size} patches processed",
|
74 |
+
end="\r",
|
75 |
+
)
|
76 |
+
patches = patches.to(self.device, non_blocking=True)
|
77 |
+
|
78 |
+
feature = self.feature_extractor(patches) # [B, 1024]
|
79 |
+
feature /= feature.norm(dim=-1, keepdim=True) # use normalized feature
|
80 |
+
feature = feature.to("cpu", non_blocking=True)
|
81 |
+
features_list.append(feature)
|
82 |
+
print("")
|
83 |
+
print("Feature extraction finished")
|
84 |
+
|
85 |
+
features = torch.cat(features_list)
|
86 |
+
|
87 |
+
# Aggregate features
|
88 |
+
self.agg_model.eval()
|
89 |
+
logits, Y_prob, Y_hat = self.agg_model(features[None].to(self.device))
|
90 |
+
probs = Y_prob[0].cpu()
|
91 |
+
|
92 |
+
return probs
|
scripts/feature_extractor.py
ADDED
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import warnings
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
|
9 |
+
class DropPath(nn.Module):
|
10 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
11 |
+
|
12 |
+
def __init__(self, drop_prob=None):
|
13 |
+
super(DropPath, self).__init__()
|
14 |
+
self.drop_prob = drop_prob
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
return _drop_path(x, self.drop_prob, self.training)
|
18 |
+
|
19 |
+
|
20 |
+
class Mlp(nn.Module):
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
in_features,
|
24 |
+
hidden_features=None,
|
25 |
+
out_features=None,
|
26 |
+
act_layer=nn.GELU,
|
27 |
+
drop=0.0,
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
out_features = out_features or in_features
|
31 |
+
hidden_features = hidden_features or in_features
|
32 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
33 |
+
self.act = act_layer()
|
34 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
35 |
+
self.drop = nn.Dropout(drop)
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
x = self.fc1(x)
|
39 |
+
x = self.act(x)
|
40 |
+
x = self.drop(x)
|
41 |
+
x = self.fc2(x)
|
42 |
+
x = self.drop(x)
|
43 |
+
return x
|
44 |
+
|
45 |
+
|
46 |
+
class Attention(nn.Module):
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
dim,
|
50 |
+
num_heads=8,
|
51 |
+
qkv_bias=False,
|
52 |
+
qk_scale=None,
|
53 |
+
attn_drop=0.0,
|
54 |
+
proj_drop=0.0,
|
55 |
+
):
|
56 |
+
super().__init__()
|
57 |
+
self.num_heads = num_heads
|
58 |
+
head_dim = dim // num_heads
|
59 |
+
self.scale = qk_scale or head_dim**-0.5
|
60 |
+
|
61 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
62 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
63 |
+
self.proj = nn.Linear(dim, dim)
|
64 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
65 |
+
|
66 |
+
def forward(self, x):
|
67 |
+
B, N, C = x.shape
|
68 |
+
qkv = (
|
69 |
+
self.qkv(x)
|
70 |
+
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
71 |
+
.permute(2, 0, 3, 1, 4)
|
72 |
+
)
|
73 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
74 |
+
|
75 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
76 |
+
attn = attn.softmax(dim=-1)
|
77 |
+
attn = self.attn_drop(attn)
|
78 |
+
|
79 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
80 |
+
x = self.proj(x)
|
81 |
+
x = self.proj_drop(x)
|
82 |
+
return x, attn
|
83 |
+
|
84 |
+
|
85 |
+
class Block(nn.Module):
|
86 |
+
def __init__(
|
87 |
+
self,
|
88 |
+
dim,
|
89 |
+
num_heads,
|
90 |
+
mlp_ratio=4.0,
|
91 |
+
qkv_bias=False,
|
92 |
+
qk_scale=None,
|
93 |
+
drop=0.0,
|
94 |
+
attn_drop=0.0,
|
95 |
+
drop_path=0.0,
|
96 |
+
act_layer=nn.GELU,
|
97 |
+
norm_layer=nn.LayerNorm,
|
98 |
+
):
|
99 |
+
super().__init__()
|
100 |
+
self.norm1 = norm_layer(dim)
|
101 |
+
self.attn = Attention(
|
102 |
+
dim,
|
103 |
+
num_heads=num_heads,
|
104 |
+
qkv_bias=qkv_bias,
|
105 |
+
qk_scale=qk_scale,
|
106 |
+
attn_drop=attn_drop,
|
107 |
+
proj_drop=drop,
|
108 |
+
)
|
109 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
110 |
+
self.norm2 = norm_layer(dim)
|
111 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
112 |
+
self.mlp = Mlp(
|
113 |
+
in_features=dim,
|
114 |
+
hidden_features=mlp_hidden_dim,
|
115 |
+
act_layer=act_layer,
|
116 |
+
drop=drop,
|
117 |
+
)
|
118 |
+
|
119 |
+
def forward(self, x, return_attention=False):
|
120 |
+
y, attn = self.attn(self.norm1(x))
|
121 |
+
if return_attention:
|
122 |
+
return attn
|
123 |
+
x = x + self.drop_path(y)
|
124 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
125 |
+
return x
|
126 |
+
|
127 |
+
|
128 |
+
class PatchEmbed(nn.Module):
|
129 |
+
"""Image to Patch Embedding"""
|
130 |
+
|
131 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
132 |
+
super().__init__()
|
133 |
+
num_patches = (img_size // patch_size) * (img_size // patch_size)
|
134 |
+
self.img_size = img_size
|
135 |
+
self.patch_size = patch_size
|
136 |
+
self.num_patches = num_patches
|
137 |
+
|
138 |
+
self.proj = nn.Conv2d(
|
139 |
+
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
|
140 |
+
)
|
141 |
+
|
142 |
+
def forward(self, x):
|
143 |
+
B, C, H, W = x.shape
|
144 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
145 |
+
return x
|
146 |
+
|
147 |
+
|
148 |
+
class VisionTransformer(nn.Module):
|
149 |
+
def __init__(
|
150 |
+
self,
|
151 |
+
img_size=[224],
|
152 |
+
patch_size=16,
|
153 |
+
in_chans=3,
|
154 |
+
num_classes=0,
|
155 |
+
embed_dim=768,
|
156 |
+
depth=12,
|
157 |
+
num_heads=12,
|
158 |
+
mlp_ratio=4.0,
|
159 |
+
qkv_bias=False,
|
160 |
+
qk_scale=None,
|
161 |
+
drop_rate=0.0,
|
162 |
+
attn_drop_rate=0.0,
|
163 |
+
drop_path_rate=0.0,
|
164 |
+
norm_layer=nn.LayerNorm,
|
165 |
+
**kwargs
|
166 |
+
):
|
167 |
+
super().__init__()
|
168 |
+
self.num_features = self.embed_dim = embed_dim
|
169 |
+
|
170 |
+
self.patch_embed = PatchEmbed(
|
171 |
+
img_size=img_size[0],
|
172 |
+
patch_size=patch_size,
|
173 |
+
in_chans=in_chans,
|
174 |
+
embed_dim=embed_dim,
|
175 |
+
)
|
176 |
+
num_patches = self.patch_embed.num_patches
|
177 |
+
|
178 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
179 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
180 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
181 |
+
|
182 |
+
dpr = [
|
183 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
184 |
+
] # stochastic depth decay rule
|
185 |
+
self.blocks = nn.ModuleList(
|
186 |
+
[
|
187 |
+
Block(
|
188 |
+
dim=embed_dim,
|
189 |
+
num_heads=num_heads,
|
190 |
+
mlp_ratio=mlp_ratio,
|
191 |
+
qkv_bias=qkv_bias,
|
192 |
+
qk_scale=qk_scale,
|
193 |
+
drop=drop_rate,
|
194 |
+
attn_drop=attn_drop_rate,
|
195 |
+
drop_path=dpr[i],
|
196 |
+
norm_layer=norm_layer,
|
197 |
+
)
|
198 |
+
for i in range(depth)
|
199 |
+
]
|
200 |
+
)
|
201 |
+
self.norm = norm_layer(embed_dim)
|
202 |
+
|
203 |
+
# Classifier head
|
204 |
+
self.head = (
|
205 |
+
nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
206 |
+
)
|
207 |
+
|
208 |
+
_trunc_normal_(self.pos_embed, std=0.02)
|
209 |
+
_trunc_normal_(self.cls_token, std=0.02)
|
210 |
+
self.apply(self._init_weights)
|
211 |
+
|
212 |
+
def _init_weights(self, m):
|
213 |
+
if isinstance(m, nn.Linear):
|
214 |
+
_trunc_normal_(m.weight, std=0.02)
|
215 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
216 |
+
nn.init.constant_(m.bias, 0)
|
217 |
+
elif isinstance(m, nn.LayerNorm):
|
218 |
+
nn.init.constant_(m.bias, 0)
|
219 |
+
nn.init.constant_(m.weight, 1.0)
|
220 |
+
|
221 |
+
def interpolate_pos_encoding(self, x, w, h):
|
222 |
+
npatch = x.shape[1] - 1
|
223 |
+
N = self.pos_embed.shape[1] - 1
|
224 |
+
if npatch == N and w == h:
|
225 |
+
return self.pos_embed
|
226 |
+
class_pos_embed = self.pos_embed[:, 0]
|
227 |
+
patch_pos_embed = self.pos_embed[:, 1:]
|
228 |
+
dim = x.shape[-1]
|
229 |
+
w0 = w // self.patch_embed.patch_size
|
230 |
+
h0 = h // self.patch_embed.patch_size
|
231 |
+
# we add a small number to avoid floating point error in the interpolation
|
232 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
233 |
+
w0, h0 = w0 + 0.1, h0 + 0.1
|
234 |
+
patch_pos_embed = nn.functional.interpolate(
|
235 |
+
patch_pos_embed.reshape(
|
236 |
+
1, int(math.sqrt(N)), int(math.sqrt(N)), dim
|
237 |
+
).permute(0, 3, 1, 2),
|
238 |
+
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
239 |
+
# size=(int(w0), int(h0)),
|
240 |
+
mode="bicubic",
|
241 |
+
)
|
242 |
+
assert (
|
243 |
+
int(w0) == patch_pos_embed.shape[-2]
|
244 |
+
and int(h0) == patch_pos_embed.shape[-1]
|
245 |
+
)
|
246 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
247 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
248 |
+
|
249 |
+
def prepare_tokens(self, x):
|
250 |
+
B, nc, w, h = x.shape
|
251 |
+
x = self.patch_embed(x) # patch linear embedding
|
252 |
+
|
253 |
+
# add the [CLS] token to the embed patch tokens
|
254 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
255 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
256 |
+
|
257 |
+
# add positional encoding to each token
|
258 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
259 |
+
|
260 |
+
return self.pos_drop(x)
|
261 |
+
|
262 |
+
def forward(self, x):
|
263 |
+
x = self.prepare_tokens(x)
|
264 |
+
for blk in self.blocks:
|
265 |
+
x = blk(x)
|
266 |
+
x = self.norm(x)
|
267 |
+
# print(x.type())
|
268 |
+
return x[:, 0]
|
269 |
+
|
270 |
+
def get_last_selfattention(self, x):
|
271 |
+
x = self.prepare_tokens(x)
|
272 |
+
for i, blk in enumerate(self.blocks):
|
273 |
+
if i < len(self.blocks) - 1:
|
274 |
+
x = blk(x)
|
275 |
+
else:
|
276 |
+
# return attention of the last block
|
277 |
+
return blk(x, return_attention=True)
|
278 |
+
|
279 |
+
def get_intermediate_layers(self, x, n=1):
|
280 |
+
x = self.prepare_tokens(x)
|
281 |
+
# we return the output tokens from the `n` last blocks
|
282 |
+
output = []
|
283 |
+
for i, blk in enumerate(self.blocks):
|
284 |
+
x = blk(x)
|
285 |
+
if len(self.blocks) - i <= n:
|
286 |
+
output.append(self.norm(x))
|
287 |
+
return output
|
288 |
+
|
289 |
+
|
290 |
+
def vit_base(patch_size=16, **kwargs):
|
291 |
+
model = VisionTransformer(
|
292 |
+
patch_size=patch_size,
|
293 |
+
embed_dim=768,
|
294 |
+
depth=12,
|
295 |
+
num_heads=12,
|
296 |
+
mlp_ratio=4,
|
297 |
+
qkv_bias=True,
|
298 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
299 |
+
**kwargs
|
300 |
+
)
|
301 |
+
return model
|
302 |
+
|
303 |
+
|
304 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
305 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
306 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
307 |
+
def norm_cdf(x):
|
308 |
+
# Computes standard normal cumulative distribution function
|
309 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
310 |
+
|
311 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
312 |
+
warnings.warn(
|
313 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
314 |
+
"The distribution of values may be incorrect.",
|
315 |
+
stacklevel=2,
|
316 |
+
)
|
317 |
+
|
318 |
+
with torch.no_grad():
|
319 |
+
# Values are generated by using a truncated uniform distribution and
|
320 |
+
# then using the inverse CDF for the normal distribution.
|
321 |
+
# Get upper and lower cdf values
|
322 |
+
l = norm_cdf((a - mean) / std)
|
323 |
+
u = norm_cdf((b - mean) / std)
|
324 |
+
|
325 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
326 |
+
# [2l-1, 2u-1].
|
327 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
328 |
+
|
329 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
330 |
+
# standard normal
|
331 |
+
tensor.erfinv_()
|
332 |
+
|
333 |
+
# Transform to proper mean, std
|
334 |
+
tensor.mul_(std * math.sqrt(2.0))
|
335 |
+
tensor.add_(mean)
|
336 |
+
|
337 |
+
# Clamp to ensure it's in the proper range
|
338 |
+
tensor.clamp_(min=a, max=b)
|
339 |
+
return tensor
|
340 |
+
|
341 |
+
|
342 |
+
def _trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
343 |
+
# type: (torch.Tensor, float, float, float, float) -> torch.Tensor
|
344 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
345 |
+
|
346 |
+
|
347 |
+
def _drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
348 |
+
if drop_prob == 0.0 or not training:
|
349 |
+
return x
|
350 |
+
keep_prob = 1 - drop_prob
|
351 |
+
shape = (x.shape[0],) + (1,) * (
|
352 |
+
x.ndim - 1
|
353 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
354 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
355 |
+
random_tensor.floor_() # binarize
|
356 |
+
output = x.div(keep_prob) * random_tensor
|
357 |
+
return output
|
scripts/inference.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from scripts.constants import CLASS_NAMES
|
2 |
+
|
3 |
+
|
4 |
+
def infer(model, input_files):
|
5 |
+
for p in input_files:
|
6 |
+
print("Processing", p, "...")
|
7 |
+
probs = model(p)
|
8 |
+
result_str = "Result -- " + " / ".join(
|
9 |
+
[f"{name}: {probs[i].item():.4f}" for i, name in enumerate(CLASS_NAMES)]
|
10 |
+
)
|
11 |
+
print(result_str + "\n")
|
scripts/preprocessor.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
from torchstain.base.normalizers.he_normalizer import HENormalizer
|
7 |
+
from torchstain.torch.utils import cov, percentile
|
8 |
+
from torchvision import transforms
|
9 |
+
from torchvision.transforms.functional import to_pil_image
|
10 |
+
|
11 |
+
|
12 |
+
def preprocessor(pretrained=False, normalizer=None):
|
13 |
+
if pretrained:
|
14 |
+
mean = (0.485, 0.456, 0.406)
|
15 |
+
std = (0.229, 0.224, 0.225)
|
16 |
+
else:
|
17 |
+
mean = (0.5, 0.5, 0.5)
|
18 |
+
std = (0.5, 0.5, 0.5)
|
19 |
+
|
20 |
+
preprocess = transforms.Compose(
|
21 |
+
[
|
22 |
+
transforms.Resize(256),
|
23 |
+
transforms.CenterCrop(224),
|
24 |
+
transforms.Lambda(lambda x: x) if normalizer == None else normalizer,
|
25 |
+
transforms.ToTensor(),
|
26 |
+
transforms.Normalize(mean=mean, std=std),
|
27 |
+
]
|
28 |
+
)
|
29 |
+
|
30 |
+
return preprocess
|
31 |
+
|
32 |
+
|
33 |
+
"""
|
34 |
+
Source code ported from: https://github.com/schaugf/HEnorm_python
|
35 |
+
Original implementation: https://github.com/mitkovetta/staining-normalization
|
36 |
+
"""
|
37 |
+
|
38 |
+
|
39 |
+
class TorchMacenkoNormalizer(HENormalizer):
|
40 |
+
def __init__(self):
|
41 |
+
super().__init__()
|
42 |
+
|
43 |
+
self.HERef = torch.tensor(
|
44 |
+
[[0.5626, 0.2159], [0.7201, 0.8012], [0.4062, 0.5581]]
|
45 |
+
)
|
46 |
+
self.maxCRef = torch.tensor([1.9705, 1.0308])
|
47 |
+
|
48 |
+
# Avoid using deprecated torch.lstsq (since 1.9.0)
|
49 |
+
self.updated_lstsq = hasattr(torch.linalg, "lstsq")
|
50 |
+
|
51 |
+
def __convert_rgb2od(self, I, Io, beta):
|
52 |
+
I = I.permute(1, 2, 0)
|
53 |
+
|
54 |
+
# calculate optical density
|
55 |
+
OD = -torch.log((I.reshape((-1, I.shape[-1])).float() + 1) / Io)
|
56 |
+
|
57 |
+
# remove transparent pixels
|
58 |
+
ODhat = OD[~torch.any(OD < beta, dim=1)]
|
59 |
+
|
60 |
+
return OD, ODhat
|
61 |
+
|
62 |
+
def __find_HE(self, ODhat, eigvecs, alpha):
|
63 |
+
# project on the plane spanned by the eigenvectors corresponding to the two
|
64 |
+
# largest eigenvalues
|
65 |
+
That = torch.matmul(ODhat, eigvecs)
|
66 |
+
phi = torch.atan2(That[:, 1], That[:, 0])
|
67 |
+
# print(phi.size())
|
68 |
+
|
69 |
+
minPhi = percentile(phi, alpha)
|
70 |
+
maxPhi = percentile(phi, 100 - alpha)
|
71 |
+
|
72 |
+
vMin = torch.matmul(
|
73 |
+
eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))
|
74 |
+
).unsqueeze(1)
|
75 |
+
vMax = torch.matmul(
|
76 |
+
eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))
|
77 |
+
).unsqueeze(1)
|
78 |
+
|
79 |
+
# a heuristic to make the vector corresponding to hematoxylin first and the
|
80 |
+
# one corresponding to eosin second
|
81 |
+
HE = torch.where(
|
82 |
+
vMin[0] > vMax[0],
|
83 |
+
torch.cat((vMin, vMax), dim=1),
|
84 |
+
torch.cat((vMax, vMin), dim=1),
|
85 |
+
)
|
86 |
+
|
87 |
+
return HE
|
88 |
+
|
89 |
+
def __find_concentration(self, OD, HE):
|
90 |
+
# rows correspond to channels (RGB), columns to OD values
|
91 |
+
Y = OD.T
|
92 |
+
|
93 |
+
# determine concentrations of the individual stains
|
94 |
+
if not self.updated_lstsq:
|
95 |
+
return torch.lstsq(Y, HE)[0][:2]
|
96 |
+
|
97 |
+
return torch.linalg.lstsq(HE, Y)[0]
|
98 |
+
|
99 |
+
def __compute_matrices(self, I, Io, alpha, beta):
|
100 |
+
OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)
|
101 |
+
|
102 |
+
# compute eigenvectors
|
103 |
+
_, eigvecs = torch.linalg.eigh(cov(ODhat.T))
|
104 |
+
eigvecs = eigvecs[:, [1, 2]]
|
105 |
+
|
106 |
+
HE = self.__find_HE(ODhat, eigvecs, alpha)
|
107 |
+
|
108 |
+
C = self.__find_concentration(OD, HE)
|
109 |
+
maxC = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])
|
110 |
+
|
111 |
+
return HE, C, maxC
|
112 |
+
|
113 |
+
def fit(self, I, Io=240, alpha=1, beta=0.15):
|
114 |
+
HE, _, maxC = self.__compute_matrices(I, Io, alpha, beta)
|
115 |
+
|
116 |
+
self.HERef = HE
|
117 |
+
self.maxCRef = maxC
|
118 |
+
|
119 |
+
def normalize(
|
120 |
+
self, I, Io=240, alpha=1, beta=0.15, stains=True, form="chw", dtype="int"
|
121 |
+
):
|
122 |
+
"""Normalize staining appearence of H&E stained images
|
123 |
+
|
124 |
+
Example use:
|
125 |
+
see test.py
|
126 |
+
|
127 |
+
Input:
|
128 |
+
I: RGB input image: tensor of shape [C, H, W] and type uint8
|
129 |
+
Io: (optional) transmitted light intensity
|
130 |
+
alpha: percentile
|
131 |
+
beta: transparency threshold
|
132 |
+
stains: if true, return also H & E components
|
133 |
+
|
134 |
+
Output:
|
135 |
+
Inorm: normalized image
|
136 |
+
H: hematoxylin image
|
137 |
+
E: eosin image
|
138 |
+
|
139 |
+
Reference:
|
140 |
+
A method for normalizing histology slides for quantitative analysis. M.
|
141 |
+
Macenko et al., ISBI 2009
|
142 |
+
"""
|
143 |
+
|
144 |
+
c, h, w = I.shape
|
145 |
+
|
146 |
+
HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta)
|
147 |
+
|
148 |
+
# normalize stain concentrations
|
149 |
+
C *= (self.maxCRef / maxC).unsqueeze(-1)
|
150 |
+
|
151 |
+
# recreate the image using reference mixing matrix
|
152 |
+
Inorm = Io * torch.exp(-torch.matmul(self.HERef, C))
|
153 |
+
Inorm = torch.clip(Inorm, 0, 255)
|
154 |
+
|
155 |
+
Inorm = Inorm.reshape(c, h, w).float() / 255.0
|
156 |
+
Inorm = torch.clip(Inorm, 0.0, 1.0)
|
157 |
+
|
158 |
+
H, E = None, None
|
159 |
+
|
160 |
+
if stains:
|
161 |
+
H = torch.mul(
|
162 |
+
Io,
|
163 |
+
torch.exp(
|
164 |
+
torch.matmul(-self.HERef[:, 0].unsqueeze(-1), C[0, :].unsqueeze(0))
|
165 |
+
),
|
166 |
+
)
|
167 |
+
H[H > 255] = 255
|
168 |
+
H = H.T.reshape(h, w, c).int()
|
169 |
+
|
170 |
+
E = torch.mul(
|
171 |
+
Io,
|
172 |
+
torch.exp(
|
173 |
+
torch.matmul(-self.HERef[:, 1].unsqueeze(-1), C[1, :].unsqueeze(0))
|
174 |
+
),
|
175 |
+
)
|
176 |
+
E[E > 255] = 255
|
177 |
+
E = E.T.reshape(h, w, c).int()
|
178 |
+
|
179 |
+
return Inorm, H, E
|
180 |
+
|
181 |
+
|
182 |
+
class MacenkoNormalizer:
|
183 |
+
def __init__(self, target_path=None, prob=1):
|
184 |
+
self.transform_before_macenko = transforms.Compose(
|
185 |
+
[transforms.ToTensor(), transforms.Lambda(lambda x: x * 255)]
|
186 |
+
)
|
187 |
+
self.normalizer = TorchMacenkoNormalizer()
|
188 |
+
|
189 |
+
ext = os.path.splitext(target_path)[1].lower()
|
190 |
+
if ext in [".jpg", ".jpeg", ".png"]:
|
191 |
+
target = Image.open(target_path)
|
192 |
+
self.normalizer.fit(self.transform_before_macenko(target))
|
193 |
+
elif ext in [".pt"]:
|
194 |
+
target = torch.load(target_path)
|
195 |
+
self.normalizer.HERef = target["HERef"]
|
196 |
+
self.normalizer.maxCRef = target["maxCRef"]
|
197 |
+
|
198 |
+
else:
|
199 |
+
raise ValueError(f"Invalid extension: {ext}")
|
200 |
+
self.prob = prob
|
201 |
+
|
202 |
+
def __call__(self, image):
|
203 |
+
t_to_transform = self.transform_before_macenko(image)
|
204 |
+
try:
|
205 |
+
image_macenko, _, _ = self.normalizer.normalize(
|
206 |
+
I=t_to_transform, stains=False, form="chw", dtype="float"
|
207 |
+
)
|
208 |
+
if torch.any(torch.isnan(image_macenko)):
|
209 |
+
return image
|
210 |
+
else:
|
211 |
+
image_macenko = to_pil_image(image_macenko)
|
212 |
+
return image_macenko
|
213 |
+
except Exception as e:
|
214 |
+
if "kthvalue()" in str(e) or "linalg.eigh" in str(e):
|
215 |
+
pass
|
216 |
+
else:
|
217 |
+
print(str(e))
|
218 |
+
return image
|
scripts/wsi_utils.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from openslide import OpenSlide
|
5 |
+
|
6 |
+
|
7 |
+
def extract_tissue_patch_coords(
|
8 |
+
wsi_path: str,
|
9 |
+
patch_size: int = 256,
|
10 |
+
step_size: int = 256,
|
11 |
+
downsample_threshold: float = 64,
|
12 |
+
threshold: int = 8,
|
13 |
+
max_val: int = 255,
|
14 |
+
median_kernel: int = 7,
|
15 |
+
close_size: int = 4,
|
16 |
+
min_effective_area_factor: float = 100, # multiplied by (ref_area)^2
|
17 |
+
ref_area: int = 512,
|
18 |
+
min_hole_area_factor: float = 16, # multiplied by (ref_area)^2
|
19 |
+
max_n_holes: int = 8,
|
20 |
+
)-> list:
|
21 |
+
"""
|
22 |
+
Extract patches from the full-resolution image whose centers fall within tissue regions.
|
23 |
+
|
24 |
+
Process:
|
25 |
+
1. Open the WSI.
|
26 |
+
2. Select a segmentation level and compute a binary mask.
|
27 |
+
3. Find contours and holes from the mask and filter them using effective area criteria.
|
28 |
+
4. Scale the external contours and holes to full resolution.
|
29 |
+
5. Slide a window over the full-resolution image and extract patches if the center is in tissue.
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
A torch tensor of shape (N, 3, patch_size, patch_size) containing the patches.
|
33 |
+
"""
|
34 |
+
slide = OpenSlide(wsi_path)
|
35 |
+
full_width, full_height = slide.level_dimensions[0]
|
36 |
+
|
37 |
+
seg_level, scale = select_segmentation_level(slide, downsample_threshold)
|
38 |
+
binary_mask = compute_segmentation_mask(
|
39 |
+
slide, seg_level, threshold, max_val, median_kernel, close_size
|
40 |
+
)
|
41 |
+
|
42 |
+
# Compute thresholds for effective area and hole area
|
43 |
+
effective_area_thresh = min_effective_area_factor * (
|
44 |
+
ref_area**2 / (scale[0] * scale[1])
|
45 |
+
)
|
46 |
+
hole_area_thresh = min_hole_area_factor * (ref_area**2 / (scale[0] * scale[1]))
|
47 |
+
|
48 |
+
ext_contours, holes_list = filter_contours_and_holes(
|
49 |
+
binary_mask, effective_area_thresh, hole_area_thresh, max_n_holes
|
50 |
+
)
|
51 |
+
if not ext_contours:
|
52 |
+
raise ValueError("No valid tissue contours found.")
|
53 |
+
|
54 |
+
tissue_contours = scale_contours(ext_contours, scale)
|
55 |
+
scaled_holes = [scale_contours(holes, scale) for holes in holes_list]
|
56 |
+
|
57 |
+
coords = []
|
58 |
+
for y in range(0, full_height - patch_size + 1, step_size):
|
59 |
+
for x in range(0, full_width - patch_size + 1, step_size):
|
60 |
+
center_x = x + patch_size // 2
|
61 |
+
center_y = y + patch_size // 2
|
62 |
+
if not point_in_tissue(center_x, center_y, tissue_contours, scaled_holes):
|
63 |
+
continue
|
64 |
+
coords.append((x, y))
|
65 |
+
|
66 |
+
if not coords:
|
67 |
+
raise ValueError("No available patches")
|
68 |
+
return coords
|
69 |
+
|
70 |
+
|
71 |
+
def select_segmentation_level(slide: OpenSlide, downsample_threshold: float = 64):
|
72 |
+
"""
|
73 |
+
Select a segmentation level whose downsample factor is at least the specified threshold.
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
level (int): Chosen level index.
|
77 |
+
scale (tuple): Downsample factors (sx, sy) for that level.
|
78 |
+
"""
|
79 |
+
level = slide.get_best_level_for_downsample(downsample_threshold)
|
80 |
+
ds = slide.level_downsamples[level]
|
81 |
+
if not isinstance(ds, (tuple, list)):
|
82 |
+
ds = (ds, ds)
|
83 |
+
return level, ds
|
84 |
+
|
85 |
+
|
86 |
+
def compute_segmentation_mask(
|
87 |
+
slide: OpenSlide,
|
88 |
+
level: int,
|
89 |
+
threshold: int = 20,
|
90 |
+
max_val: int = 255,
|
91 |
+
median_kernel: int = 7,
|
92 |
+
close_size: int = 4,
|
93 |
+
):
|
94 |
+
"""
|
95 |
+
Compute a binary mask for tissue segmentation at the specified level.
|
96 |
+
|
97 |
+
Process:
|
98 |
+
- Read the image at the given level and convert to RGB.
|
99 |
+
- Convert the image to HSV and extract the saturation channel.
|
100 |
+
- Apply median blur.
|
101 |
+
- Apply binary thresholding (either fixed or Otsu).
|
102 |
+
- Apply morphological closing.
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
binary (ndarray): Binary mask image.
|
106 |
+
"""
|
107 |
+
img = np.array(
|
108 |
+
slide.read_region((0, 0), level, slide.level_dimensions[level]).convert("RGB")
|
109 |
+
)
|
110 |
+
hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
|
111 |
+
sat = hsv[:, :, 1]
|
112 |
+
blurred = cv2.medianBlur(sat, median_kernel)
|
113 |
+
_, binary = cv2.threshold(blurred, threshold, max_val, cv2.THRESH_BINARY)
|
114 |
+
if close_size > 0:
|
115 |
+
kernel = np.ones((close_size, close_size), np.uint8)
|
116 |
+
binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
|
117 |
+
return binary
|
118 |
+
|
119 |
+
|
120 |
+
def filter_contours_and_holes(
|
121 |
+
binary_mask: np.ndarray,
|
122 |
+
min_effective_area: float,
|
123 |
+
min_hole_area: float,
|
124 |
+
max_n_holes: int,
|
125 |
+
):
|
126 |
+
"""
|
127 |
+
Find contours from the binary mask and filter them based on effective area.
|
128 |
+
|
129 |
+
For each external contour (one with no parent), identify child contours (holes),
|
130 |
+
sort them by area (largest first), and keep up to max_n_holes that exceed min_hole_area.
|
131 |
+
The effective area is computed as the area of the external contour minus the sum of areas
|
132 |
+
of the selected holes. Only contours with effective area above min_effective_area are retained.
|
133 |
+
|
134 |
+
Returns:
|
135 |
+
filtered_contours (list): List of external contours (numpy arrays).
|
136 |
+
holes_list (list): Corresponding list of lists of hole contours.
|
137 |
+
"""
|
138 |
+
contours, hierarchy = cv2.findContours(
|
139 |
+
binary_mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
|
140 |
+
)
|
141 |
+
if hierarchy is None:
|
142 |
+
return [], []
|
143 |
+
hierarchy = hierarchy[0] # shape: (N, 4)
|
144 |
+
filtered_contours = []
|
145 |
+
holes_list = []
|
146 |
+
for idx, h in enumerate(hierarchy):
|
147 |
+
if h[3] != -1:
|
148 |
+
continue # Only external contours
|
149 |
+
ext_cont = contours[idx]
|
150 |
+
ext_area = cv2.contourArea(ext_cont)
|
151 |
+
# Find child contours (holes)
|
152 |
+
hole_idxs = [i for i, hr in enumerate(hierarchy) if hr[3] == idx]
|
153 |
+
# Sort holes by area descending and keep up to max_n_holes
|
154 |
+
sorted_holes = sorted(
|
155 |
+
[contours[i] for i in hole_idxs], key=cv2.contourArea, reverse=True
|
156 |
+
)
|
157 |
+
selected_holes = [
|
158 |
+
hole
|
159 |
+
for hole in sorted_holes[:max_n_holes]
|
160 |
+
if cv2.contourArea(hole) > min_hole_area
|
161 |
+
]
|
162 |
+
total_hole_area = sum(cv2.contourArea(hole) for hole in selected_holes)
|
163 |
+
effective_area = ext_area - total_hole_area
|
164 |
+
if effective_area > min_effective_area:
|
165 |
+
filtered_contours.append(ext_cont)
|
166 |
+
holes_list.append(selected_holes)
|
167 |
+
return filtered_contours, holes_list
|
168 |
+
|
169 |
+
|
170 |
+
def scale_contours(contours: list, scale: tuple) -> list:
|
171 |
+
"""
|
172 |
+
Scale contour coordinates by the provided scale factors.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
contours: List of contours (each a numpy array of points).
|
176 |
+
scale: Tuple (sx, sy) for scaling.
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
List of scaled contours.
|
180 |
+
"""
|
181 |
+
scaled = []
|
182 |
+
for cont in contours:
|
183 |
+
scaled.append((cont * np.array(scale, dtype=np.float32)).astype(np.int32))
|
184 |
+
return scaled
|
185 |
+
|
186 |
+
|
187 |
+
def point_in_tissue(x: int, y: int, ext_contours: list, holes_list: list) -> bool:
|
188 |
+
"""
|
189 |
+
Check if point (x, y) lies within any external contour and not inside its corresponding holes.
|
190 |
+
|
191 |
+
For each external contour in ext_contours (paired with holes_list),
|
192 |
+
if the point is inside the contour and not inside any of its holes, return True.
|
193 |
+
"""
|
194 |
+
for cont, holes in zip(ext_contours, holes_list):
|
195 |
+
if cv2.pointPolygonTest(cont, (x, y), False) >= 0:
|
196 |
+
inside_hole = False
|
197 |
+
for hole in holes:
|
198 |
+
if cv2.pointPolygonTest(hole, (x, y), False) >= 0:
|
199 |
+
inside_hole = True
|
200 |
+
break
|
201 |
+
if not inside_hole:
|
202 |
+
return True
|
203 |
+
return False
|
204 |
+
|
205 |
+
|
206 |
+
def tile(x: torch.Tensor, size: int):
|
207 |
+
C, H, W = x.shape[-3:]
|
208 |
+
|
209 |
+
pad_h = (size - H % size) % size
|
210 |
+
pad_w = (size - W % size) % size
|
211 |
+
if pad_h > 0 or pad_w > 0:
|
212 |
+
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h))
|
213 |
+
|
214 |
+
nh, nw = x.size(2) // size, x.size(3) // size
|
215 |
+
return (
|
216 |
+
x.view(-1, C, nh, size, nw, size)
|
217 |
+
.permute(0, 2, 4, 1, 3, 5)
|
218 |
+
.reshape(-1, C, size, size)
|
219 |
+
)
|