Setup instructions, docstrings
Browse files- README.md +45 -1
- app.py +35 -3
- environment.yml +247 -18
- labelmap.py +2 -0
- requirements.txt +74 -0
- train.py +69 -15
README.md
CHANGED
@@ -1 +1,45 @@
|
|
1 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Diabetic Retinopathy Detection with AI
|
2 |
+
|
3 |
+
## Setup
|
4 |
+
|
5 |
+
### Gradio app environment
|
6 |
+
|
7 |
+
Install from pip requirements file:
|
8 |
+
|
9 |
+
```bash
|
10 |
+
conda create -n retinopathy_app python=3.10
|
11 |
+
conda activate retinopathy_app
|
12 |
+
pip install -r requirements.txt
|
13 |
+
python app.py
|
14 |
+
```
|
15 |
+
|
16 |
+
Install manually:
|
17 |
+
|
18 |
+
```bash
|
19 |
+
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
20 |
+
pip install gradio
|
21 |
+
pip install transformers
|
22 |
+
```
|
23 |
+
|
24 |
+
### Training environment
|
25 |
+
|
26 |
+
Create conda environment from YAML:
|
27 |
+
```bash
|
28 |
+
mamba env create -n retinopathy_train -f environment.yml
|
29 |
+
```
|
30 |
+
|
31 |
+
Download the data from [Kaggle](https://www.kaggle.com/competitions/diabetic-retinopathy-detection/data) or use the Kaggle API:
|
32 |
+
|
33 |
+
```bash
|
34 |
+
pip install kaggle
|
35 |
+
kaggle competitions download -c diabetic-retinopathy-detection
|
36 |
+
mkdir retinopathy_data/
|
37 |
+
unzip diabetic-retinopathy-detection.zip -d retinopathy_data/
|
38 |
+
```
|
39 |
+
|
40 |
+
Launch training:
|
41 |
+
```bash
|
42 |
+
conda activate retinopathy_train
|
43 |
+
python train.py
|
44 |
+
```
|
45 |
+
The trained model will be put into `lightning_logs/`.
|
app.py
CHANGED
@@ -2,7 +2,7 @@ import os
|
|
2 |
import gradio as gr
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
-
from typing import Tuple, Optional, Dict, List
|
6 |
import glob
|
7 |
from collections import defaultdict
|
8 |
|
@@ -13,7 +13,10 @@ from labelmap import DR_LABELMAP
|
|
13 |
|
14 |
|
15 |
class App:
|
|
|
|
|
16 |
def __init__(self) -> None:
|
|
|
17 |
|
18 |
ckpt_name = "2023-12-24_20-02-18_30345221_V100_x4_resnet34/"
|
19 |
|
@@ -41,7 +44,7 @@ class App:
|
|
41 |
output = gr.Label(num_top_classes=len(DR_LABELMAP),
|
42 |
label="Retinopathy level prediction")
|
43 |
with gr.Column(scale=4):
|
44 |
-
gr.Markdown(":
|
46 |
with gr.Column(scale=9, min_width=100):
|
47 |
image = gr.Image(label="Retina scan")
|
@@ -66,9 +69,19 @@ class App:
|
|
66 |
self.ui = ui
|
67 |
|
68 |
def launch(self) -> None:
|
|
|
69 |
self.ui.queue().launch(share=True)
|
70 |
|
71 |
-
def predict(self, image: Optional[np.ndarray]):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
if image is None:
|
73 |
return dict()
|
74 |
cls_name, prob, probs = self._infer(image)
|
@@ -79,6 +92,19 @@ class App:
|
|
79 |
return probs_dict
|
80 |
|
81 |
def _infer(self, image_chw: np.ndarray) -> Tuple[str, float, np.ndarray]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
assert isinstance(self.model, ResNetForImageClassification)
|
83 |
|
84 |
inputs = self.image_processor(image_chw, return_tensors="pt")
|
@@ -98,6 +124,11 @@ class App:
|
|
98 |
|
99 |
@staticmethod
|
100 |
def _load_example_lists() -> Dict[int, List[str]]:
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
example_flat_list = glob.glob("demo_data/train/**/*.jpeg")
|
103 |
|
@@ -115,6 +146,7 @@ class App:
|
|
115 |
|
116 |
|
117 |
def main():
|
|
|
118 |
app = App()
|
119 |
app.launch()
|
120 |
|
|
|
2 |
import gradio as gr
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
+
from typing import Tuple, Optional, Dict, List, Dict
|
6 |
import glob
|
7 |
from collections import defaultdict
|
8 |
|
|
|
13 |
|
14 |
|
15 |
class App:
|
16 |
+
""" Demonstration of the Diabetic Retinopathy model as a Gradio app. """
|
17 |
+
|
18 |
def __init__(self) -> None:
|
19 |
+
""" Constructor. """
|
20 |
|
21 |
ckpt_name = "2023-12-24_20-02-18_30345221_V100_x4_resnet34/"
|
22 |
|
|
|
44 |
output = gr.Label(num_top_classes=len(DR_LABELMAP),
|
45 |
label="Retinopathy level prediction")
|
46 |
with gr.Column(scale=4):
|
47 |
+
gr.Markdown("")
|
48 |
with gr.Row():
|
49 |
with gr.Column(scale=9, min_width=100):
|
50 |
image = gr.Image(label="Retina scan")
|
|
|
69 |
self.ui = ui
|
70 |
|
71 |
def launch(self) -> None:
|
72 |
+
""" Launch the application, blocking. """
|
73 |
self.ui.queue().launch(share=True)
|
74 |
|
75 |
+
def predict(self, image: Optional[np.ndarray]) -> Dict[str, float]:
|
76 |
+
""" Gradio callback for processing of an image.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
image (Optional[np.ndarray]): Provided image.
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
Dict[str, float]: Label-compatible dict.
|
83 |
+
"""
|
84 |
+
|
85 |
if image is None:
|
86 |
return dict()
|
87 |
cls_name, prob, probs = self._infer(image)
|
|
|
92 |
return probs_dict
|
93 |
|
94 |
def _infer(self, image_chw: np.ndarray) -> Tuple[str, float, np.ndarray]:
|
95 |
+
""" Low-level method to perform neural network inference.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
image_chw (np.ndarray): Provided image.
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
Tuple[str, float, np.ndarray]:
|
102 |
+
- Most probable class name
|
103 |
+
- Probability of the most probable class name.
|
104 |
+
- Probabilities of all classes in the order of
|
105 |
+
being listed in the label map.
|
106 |
+
"""
|
107 |
+
|
108 |
assert isinstance(self.model, ResNetForImageClassification)
|
109 |
|
110 |
inputs = self.image_processor(image_chw, return_tensors="pt")
|
|
|
124 |
|
125 |
@staticmethod
|
126 |
def _load_example_lists() -> Dict[int, List[str]]:
|
127 |
+
""" Load example retina images from disk.
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
Dict[int, List[str]]: Dictionary of cls_id -> list of images paths.
|
131 |
+
"""
|
132 |
|
133 |
example_flat_list = glob.glob("demo_data/train/**/*.jpeg")
|
134 |
|
|
|
146 |
|
147 |
|
148 |
def main():
|
149 |
+
""" App entry point. """
|
150 |
app = App()
|
151 |
app.launch()
|
152 |
|
environment.yml
CHANGED
@@ -1,22 +1,251 @@
|
|
1 |
-
name:
|
2 |
-
|
3 |
channels:
|
4 |
-
-
|
5 |
-
- nvidia
|
6 |
- conda-forge
|
7 |
- defaults
|
8 |
-
|
9 |
dependencies:
|
10 |
-
-
|
11 |
-
-
|
12 |
-
-
|
13 |
-
-
|
14 |
-
-
|
15 |
-
-
|
16 |
-
-
|
17 |
-
-
|
18 |
-
-
|
19 |
-
-
|
20 |
-
-
|
21 |
-
-
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: retinopathy
|
|
|
2 |
channels:
|
3 |
+
- anaconda
|
|
|
4 |
- conda-forge
|
5 |
- defaults
|
|
|
6 |
dependencies:
|
7 |
+
- _libgcc_mutex=0.1=main
|
8 |
+
- _openmp_mutex=5.1=1_gnu
|
9 |
+
- aiofiles=22.1.0=py310h06a4308_0
|
10 |
+
- aiosqlite=0.18.0=py310h06a4308_0
|
11 |
+
- argon2-cffi=21.3.0=pyhd3eb1b0_0
|
12 |
+
- argon2-cffi-bindings=21.2.0=py310h7f8727e_0
|
13 |
+
- asttokens=2.0.5=pyhd3eb1b0_0
|
14 |
+
- attrs=23.1.0=py310h06a4308_0
|
15 |
+
- babel=2.11.0=py310h06a4308_0
|
16 |
+
- backcall=0.2.0=pyhd3eb1b0_0
|
17 |
+
- beautifulsoup4=4.12.2=py310h06a4308_0
|
18 |
+
- bleach=4.1.0=pyhd3eb1b0_0
|
19 |
+
- brotli-python=1.0.9=py310h6a678d5_7
|
20 |
+
- bzip2=1.0.8=h7b6447c_0
|
21 |
+
- ca-certificates=2023.08.22=h06a4308_0
|
22 |
+
- certifi=2023.11.17=py310h06a4308_0
|
23 |
+
- cffi=1.16.0=py310h5eee18b_0
|
24 |
+
- comm=0.1.2=py310h06a4308_0
|
25 |
+
- cryptography=41.0.3=py310hdda0065_0
|
26 |
+
- debugpy=1.6.7=py310h6a678d5_0
|
27 |
+
- decorator=5.1.1=pyhd3eb1b0_0
|
28 |
+
- defusedxml=0.7.1=pyhd3eb1b0_0
|
29 |
+
- executing=0.8.3=pyhd3eb1b0_0
|
30 |
+
- ipykernel=6.25.0=py310h2f386ee_0
|
31 |
+
- ipython=8.15.0=py310h06a4308_0
|
32 |
+
- ipython_genutils=0.2.0=pyhd3eb1b0_1
|
33 |
+
- jedi=0.18.1=py310h06a4308_1
|
34 |
+
- jinja2=3.1.2=py310h06a4308_0
|
35 |
+
- json5=0.9.6=pyhd3eb1b0_0
|
36 |
+
- jsonschema=4.19.2=py310h06a4308_0
|
37 |
+
- jsonschema-specifications=2023.7.1=py310h06a4308_0
|
38 |
+
- jupyter_client=8.6.0=py310h06a4308_0
|
39 |
+
- jupyter_core=5.5.0=py310h06a4308_0
|
40 |
+
- jupyter_events=0.8.0=py310h06a4308_0
|
41 |
+
- jupyter_server=2.10.0=py310h06a4308_0
|
42 |
+
- jupyter_server_fileid=0.9.0=py310h06a4308_0
|
43 |
+
- jupyter_server_terminals=0.4.4=py310h06a4308_1
|
44 |
+
- jupyter_server_ydoc=0.8.0=py310h06a4308_1
|
45 |
+
- jupyter_ydoc=0.2.4=py310h06a4308_0
|
46 |
+
- jupyterlab=3.6.3=py310h06a4308_0
|
47 |
+
- jupyterlab_pygments=0.2.2=py310h06a4308_0
|
48 |
+
- jupyterlab_server=2.25.1=py310h06a4308_0
|
49 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
50 |
+
- libffi=3.4.4=h6a678d5_0
|
51 |
+
- libgcc-ng=11.2.0=h1234567_1
|
52 |
+
- libgomp=11.2.0=h1234567_1
|
53 |
+
- libsodium=1.0.18=h7b6447c_0
|
54 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
55 |
+
- libuuid=1.41.5=h5eee18b_0
|
56 |
+
- matplotlib-inline=0.1.6=py310h06a4308_0
|
57 |
+
- mistune=2.0.4=py310h06a4308_0
|
58 |
+
- nbclassic=1.0.0=py310h06a4308_0
|
59 |
+
- nbclient=0.8.0=py310h06a4308_0
|
60 |
+
- nbconvert=7.10.0=py310h06a4308_0
|
61 |
+
- nbformat=5.9.2=py310h06a4308_0
|
62 |
+
- ncurses=6.4=h6a678d5_0
|
63 |
+
- nest-asyncio=1.5.6=py310h06a4308_0
|
64 |
+
- notebook=6.5.4=py310h06a4308_0
|
65 |
+
- notebook-shim=0.2.3=py310h06a4308_0
|
66 |
+
- openssl=3.0.12=h7f8727e_0
|
67 |
+
- overrides=7.4.0=py310h06a4308_0
|
68 |
+
- pandocfilters=1.5.0=pyhd3eb1b0_0
|
69 |
+
- parso=0.8.3=pyhd3eb1b0_0
|
70 |
+
- pexpect=4.8.0=pyhd3eb1b0_3
|
71 |
+
- pickleshare=0.7.5=pyhd3eb1b0_1003
|
72 |
+
- platformdirs=3.10.0=py310h06a4308_0
|
73 |
+
- prometheus_client=0.14.1=py310h06a4308_0
|
74 |
+
- prompt-toolkit=3.0.36=py310h06a4308_0
|
75 |
+
- ptyprocess=0.7.0=pyhd3eb1b0_2
|
76 |
+
- pure_eval=0.2.2=pyhd3eb1b0_0
|
77 |
+
- pycparser=2.21=pyhd3eb1b0_0
|
78 |
+
- pyopenssl=23.2.0=py310h06a4308_0
|
79 |
+
- pysocks=1.7.1=py310h06a4308_0
|
80 |
+
- python=3.10.13=h955ad1f_0
|
81 |
+
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
82 |
+
- python-fastjsonschema=2.16.2=py310h06a4308_0
|
83 |
+
- python-json-logger=2.0.7=py310h06a4308_0
|
84 |
+
- pytz=2023.3.post1=py310h06a4308_0
|
85 |
+
- pyyaml=6.0.1=py310h5eee18b_0
|
86 |
+
- pyzmq=25.1.0=py310h6a678d5_0
|
87 |
+
- readline=8.2=h5eee18b_0
|
88 |
+
- referencing=0.30.2=py310h06a4308_0
|
89 |
+
- requests=2.31.0=py310h06a4308_0
|
90 |
+
- rfc3339-validator=0.1.4=py310h06a4308_0
|
91 |
+
- rfc3986-validator=0.1.1=py310h06a4308_0
|
92 |
+
- rpds-py=0.10.6=py310hb02cf49_0
|
93 |
+
- send2trash=1.8.2=py310h06a4308_0
|
94 |
+
- setuptools=68.0.0=py310h06a4308_0
|
95 |
+
- six=1.16.0=pyhd3eb1b0_1
|
96 |
+
- soupsieve=2.5=py310h06a4308_0
|
97 |
+
- sqlite=3.41.2=h5eee18b_0
|
98 |
+
- stack_data=0.2.0=pyhd3eb1b0_0
|
99 |
+
- terminado=0.17.1=py310h06a4308_0
|
100 |
+
- tinycss2=1.2.1=py310h06a4308_0
|
101 |
+
- tk=8.6.12=h1ccaba5_0
|
102 |
+
- tomli=2.0.1=py310h06a4308_0
|
103 |
+
- tornado=6.3.3=py310h5eee18b_0
|
104 |
+
- webencodings=0.5.1=py310h06a4308_1
|
105 |
+
- wheel=0.41.2=py310h06a4308_0
|
106 |
+
- y-py=0.5.9=py310h52d8a92_0
|
107 |
+
- yaml=0.2.5=h7b6447c_0
|
108 |
+
- ypy-websocket=0.8.2=py310h06a4308_0
|
109 |
+
- zeromq=4.3.4=h2531618_0
|
110 |
+
- zlib=1.2.13=h5eee18b_0
|
111 |
+
- pip:
|
112 |
+
- absl-py==2.0.0
|
113 |
+
- aiobotocore==2.8.0
|
114 |
+
- aiohttp==3.9.1
|
115 |
+
- aioitertools==0.11.0
|
116 |
+
- aiosignal==1.3.1
|
117 |
+
- altair==5.2.0
|
118 |
+
- annotated-types==0.6.0
|
119 |
+
- antlr4-python3-runtime==4.9.3
|
120 |
+
- anyio==3.7.1
|
121 |
+
- arrow==1.3.0
|
122 |
+
- async-timeout==4.0.3
|
123 |
+
- backoff==2.2.1
|
124 |
+
- bitsandbytes==0.41.3
|
125 |
+
- blessed==1.20.0
|
126 |
+
- boto3==1.33.1
|
127 |
+
- botocore==1.33.1
|
128 |
+
- cachetools==5.3.2
|
129 |
+
- chardet==5.2.0
|
130 |
+
- charset-normalizer==3.3.2
|
131 |
+
- click==8.1.7
|
132 |
+
- colorama==0.4.6
|
133 |
+
- contourpy==1.2.0
|
134 |
+
- croniter==1.4.1
|
135 |
+
- cycler==0.12.1
|
136 |
+
- dateutils==0.6.12
|
137 |
+
- deepdiff==6.7.1
|
138 |
+
- docker==6.1.3
|
139 |
+
- docstring-parser==0.15
|
140 |
+
- exceptiongroup==1.2.0
|
141 |
+
- fastapi==0.104.1
|
142 |
+
- ffmpy==0.3.1
|
143 |
+
- filelock==3.13.1
|
144 |
+
- fonttools==4.46.0
|
145 |
+
- frozenlist==1.4.0
|
146 |
+
- fsspec==2023.12.1
|
147 |
+
- google-auth==2.25.1
|
148 |
+
- google-auth-oauthlib==1.1.0
|
149 |
+
- gradio==4.12.0
|
150 |
+
- gradio-client==0.8.0
|
151 |
+
- grpcio==1.59.3
|
152 |
+
- h11==0.14.0
|
153 |
+
- httpcore==1.0.2
|
154 |
+
- httpx==0.26.0
|
155 |
+
- huggingface-hub==0.19.4
|
156 |
+
- hydra-core==1.3.2
|
157 |
+
- idna==3.6
|
158 |
+
- importlib-resources==6.1.1
|
159 |
+
- inquirer==3.1.4
|
160 |
+
- itsdangerous==2.1.2
|
161 |
+
- jmespath==1.0.1
|
162 |
+
- jsonargparse==4.27.1
|
163 |
+
- kiwisolver==1.4.5
|
164 |
+
- lightning==2.1.2
|
165 |
+
- lightning-api-access==0.0.5
|
166 |
+
- lightning-cloud==0.5.52
|
167 |
+
- lightning-fabric==2.1.2
|
168 |
+
- lightning-utilities==0.10.0
|
169 |
+
- line-profiler==4.1.2
|
170 |
+
- markdown==3.5.1
|
171 |
+
- markdown-it-py==3.0.0
|
172 |
+
- markupsafe==2.1.3
|
173 |
+
- matplotlib==3.8.2
|
174 |
+
- mdurl==0.1.2
|
175 |
+
- mpmath==1.3.0
|
176 |
+
- multidict==6.0.4
|
177 |
+
- networkx==3.2.1
|
178 |
+
- numpy==1.26.2
|
179 |
+
- nvidia-cublas-cu12==12.1.3.1
|
180 |
+
- nvidia-cuda-cupti-cu12==12.1.105
|
181 |
+
- nvidia-cuda-nvrtc-cu12==12.1.105
|
182 |
+
- nvidia-cuda-runtime-cu12==12.1.105
|
183 |
+
- nvidia-cudnn-cu12==8.9.2.26
|
184 |
+
- nvidia-cufft-cu12==11.0.2.54
|
185 |
+
- nvidia-curand-cu12==10.3.2.106
|
186 |
+
- nvidia-cusolver-cu12==11.4.5.107
|
187 |
+
- nvidia-cusparse-cu12==12.1.0.106
|
188 |
+
- nvidia-nccl-cu12==2.18.1
|
189 |
+
- nvidia-nvjitlink-cu12==12.3.101
|
190 |
+
- nvidia-nvtx-cu12==12.1.105
|
191 |
+
- oauthlib==3.2.2
|
192 |
+
- omegaconf==2.3.0
|
193 |
+
- ordered-set==4.1.0
|
194 |
+
- orjson==3.9.10
|
195 |
+
- packaging==23.2
|
196 |
+
- pandas==2.1.3
|
197 |
+
- pillow==10.1.0
|
198 |
+
- protobuf==4.23.4
|
199 |
+
- psutil==5.9.6
|
200 |
+
- pyasn1==0.5.1
|
201 |
+
- pyasn1-modules==0.3.0
|
202 |
+
- pydantic==2.5.2
|
203 |
+
- pydantic-core==2.14.5
|
204 |
+
- pydub==0.25.1
|
205 |
+
- pygments==2.17.2
|
206 |
+
- pyjwt==2.8.0
|
207 |
+
- pyparsing==3.1.1
|
208 |
+
- python-editor==1.0.4
|
209 |
+
- python-multipart==0.0.6
|
210 |
+
- pytorch-lightning==2.1.2
|
211 |
+
- readchar==4.0.5
|
212 |
+
- redis==5.0.1
|
213 |
+
- regex==2023.10.3
|
214 |
+
- requests-oauthlib==1.3.1
|
215 |
+
- rich==13.7.0
|
216 |
+
- rsa==4.9
|
217 |
+
- s3fs==2023.12.1
|
218 |
+
- s3transfer==0.8.0
|
219 |
+
- safetensors==0.4.1
|
220 |
+
- semantic-version==2.10.0
|
221 |
+
- shellingham==1.5.4
|
222 |
+
- sniffio==1.3.0
|
223 |
+
- starlette==0.27.0
|
224 |
+
- starsessions==1.3.0
|
225 |
+
- sympy==1.12
|
226 |
+
- tensorboard==2.15.1
|
227 |
+
- tensorboard-data-server==0.7.2
|
228 |
+
- tensorboardx==2.6.2.2
|
229 |
+
- tokenizers==0.15.0
|
230 |
+
- tomlkit==0.12.0
|
231 |
+
- toolz==0.12.0
|
232 |
+
- torch==2.1.1
|
233 |
+
- torchmetrics==1.2.1
|
234 |
+
- torchvision==0.16.1
|
235 |
+
- tqdm==4.66.1
|
236 |
+
- traitlets==5.14.0
|
237 |
+
- transformers==4.35.2
|
238 |
+
- triton==2.1.0
|
239 |
+
- typer==0.9.0
|
240 |
+
- types-python-dateutil==2.8.19.14
|
241 |
+
- typeshed-client==2.4.0
|
242 |
+
- typing-extensions==4.9.0
|
243 |
+
- tzdata==2023.3
|
244 |
+
- urllib3==2.0.7
|
245 |
+
- uvicorn==0.24.0.post1
|
246 |
+
- wcwidth==0.2.12
|
247 |
+
- websocket-client==1.7.0
|
248 |
+
- websockets==11.0.3
|
249 |
+
- werkzeug==3.0.1
|
250 |
+
- wrapt==1.16.0
|
251 |
+
- yarl==1.9.4
|
labelmap.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
DR_LABELMAP = {
|
2 |
0: 'No DR',
|
3 |
1: 'Mild',
|
|
|
1 |
+
""" Mapping of class IDs to labels. """
|
2 |
+
|
3 |
DR_LABELMAP = {
|
4 |
0: 'No DR',
|
5 |
1: 'Mild',
|
requirements.txt
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cpu
|
2 |
+
aiofiles==23.2.1
|
3 |
+
altair==5.2.0
|
4 |
+
annotated-types==0.6.0
|
5 |
+
anyio==4.2.0
|
6 |
+
attrs==23.2.0
|
7 |
+
certifi==2023.11.17
|
8 |
+
charset-normalizer==3.3.2
|
9 |
+
click==8.1.7
|
10 |
+
colorama==0.4.6
|
11 |
+
contourpy==1.2.0
|
12 |
+
cycler==0.12.1
|
13 |
+
exceptiongroup==1.2.0
|
14 |
+
fastapi==0.108.0
|
15 |
+
ffmpy==0.3.1
|
16 |
+
filelock==3.13.1
|
17 |
+
fonttools==4.47.0
|
18 |
+
fsspec==2023.12.2
|
19 |
+
gradio==4.13.0
|
20 |
+
gradio_client==0.8.0
|
21 |
+
h11==0.14.0
|
22 |
+
httpcore==1.0.2
|
23 |
+
httpx==0.26.0
|
24 |
+
huggingface-hub==0.20.2
|
25 |
+
idna==3.6
|
26 |
+
importlib-resources==6.1.1
|
27 |
+
Jinja2==3.1.2
|
28 |
+
jsonschema==4.20.0
|
29 |
+
jsonschema-specifications==2023.12.1
|
30 |
+
kiwisolver==1.4.5
|
31 |
+
markdown-it-py==3.0.0
|
32 |
+
MarkupSafe==2.1.3
|
33 |
+
matplotlib==3.8.2
|
34 |
+
mdurl==0.1.2
|
35 |
+
mpmath==1.3.0
|
36 |
+
networkx==3.2.1
|
37 |
+
numpy==1.26.3
|
38 |
+
orjson==3.9.10
|
39 |
+
packaging==23.2
|
40 |
+
pandas==2.1.4
|
41 |
+
pillow==10.2.0
|
42 |
+
pydantic==2.5.3
|
43 |
+
pydantic_core==2.14.6
|
44 |
+
pydub==0.25.1
|
45 |
+
Pygments==2.17.2
|
46 |
+
pyparsing==3.1.1
|
47 |
+
python-dateutil==2.8.2
|
48 |
+
python-multipart==0.0.6
|
49 |
+
pytz==2023.3.post1
|
50 |
+
PyYAML==6.0.1
|
51 |
+
referencing==0.32.1
|
52 |
+
regex==2023.12.25
|
53 |
+
requests==2.31.0
|
54 |
+
rich==13.7.0
|
55 |
+
rpds-py==0.16.2
|
56 |
+
safetensors==0.4.1
|
57 |
+
semantic-version==2.10.0
|
58 |
+
shellingham==1.5.4
|
59 |
+
six==1.16.0
|
60 |
+
sniffio==1.3.0
|
61 |
+
starlette==0.32.0.post1
|
62 |
+
sympy==1.12
|
63 |
+
tokenizers==0.15.0
|
64 |
+
tomlkit==0.12.0
|
65 |
+
toolz==0.12.0
|
66 |
+
torch==2.1.2+cpu
|
67 |
+
tqdm==4.66.1
|
68 |
+
transformers==4.36.2
|
69 |
+
typer==0.9.0
|
70 |
+
typing_extensions==4.9.0
|
71 |
+
tzdata==2023.4
|
72 |
+
urllib3==2.1.0
|
73 |
+
uvicorn==0.25.0
|
74 |
+
websockets==11.0.3
|
train.py
CHANGED
@@ -49,7 +49,15 @@ DataRecord = Tuple[Image.Image, int]
|
|
49 |
|
50 |
|
51 |
class RetinopathyDataset(data.Dataset[DataRecord]):
|
|
|
|
|
52 |
def __init__(self, data_path: str) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
super().__init__()
|
54 |
|
55 |
self.data_path = data_path
|
@@ -88,21 +96,25 @@ class RetinopathyDataset(data.Dataset[DataRecord]):
|
|
88 |
return img_path
|
89 |
|
90 |
|
|
|
91 |
class Purpose(Enum):
|
92 |
Train = 0
|
93 |
Val = 1
|
94 |
|
95 |
-
|
96 |
FeatureAndTargetTransforms = Tuple[Callable[..., torch.Tensor],
|
97 |
Callable[..., torch.Tensor]]
|
98 |
|
|
|
99 |
TensorRecord = Tuple[torch.Tensor, torch.Tensor]
|
100 |
|
101 |
-
def normalize(arr: np.ndarray) -> np.ndarray:
|
102 |
-
return arr / np.sum(arr)
|
103 |
-
|
104 |
|
105 |
class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
|
|
|
|
|
|
|
|
|
|
|
106 |
def __init__(self, dataset: RetinopathyDataset,
|
107 |
indices: np.ndarray,
|
108 |
purpose: Purpose,
|
@@ -111,7 +123,24 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
|
|
111 |
stratify_classes: bool = False,
|
112 |
use_log_frequencies: bool = False,
|
113 |
):
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
self.dataset = dataset
|
116 |
self.indices = indices
|
117 |
self.purpose = purpose
|
@@ -124,22 +153,26 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
|
|
124 |
self.per_class_indices: Optional[Dict[int, np.ndarray]] = None
|
125 |
self.frequencies: Optional[Dict[int, float]] = None
|
126 |
if self.stratify_classes:
|
127 |
-
self.
|
128 |
if self.use_log_frequencies:
|
129 |
-
self.
|
130 |
|
131 |
-
def
|
132 |
assert self.per_class_indices is not None
|
133 |
counts_dict = {lbl: len(arr) for lbl, arr in self.per_class_indices.items()}
|
134 |
counts = np.array(list(counts_dict.values()))
|
135 |
-
counts_nrm =
|
136 |
temperature = 50.0 # > 1 to even-out frequencies
|
137 |
-
freqs =
|
138 |
self.frequencies = {k: freq.item() for k, freq
|
139 |
in zip(self.per_class_indices.keys(), freqs)}
|
140 |
print(self.frequencies)
|
141 |
|
142 |
-
|
|
|
|
|
|
|
|
|
143 |
buckets = defaultdict(list)
|
144 |
for index in self.indices:
|
145 |
label = self.dataset.get_label_at(index)
|
@@ -191,6 +224,14 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
|
|
191 |
seed: int = 54,
|
192 |
) -> Tuple['Split', 'Split']:
|
193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
prng = RandomState(seed)
|
195 |
|
196 |
num_train = int(len(all_data) * train_fraction)
|
@@ -204,7 +245,8 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
|
|
204 |
return train_data, val_data
|
205 |
|
206 |
|
207 |
-
def print_data_stats(dataset: Union[Iterable[DataRecord], DataLoader],
|
|
|
208 |
labels = []
|
209 |
for _, label in dataset:
|
210 |
if isinstance(label, torch.Tensor):
|
@@ -261,7 +303,16 @@ class Metrics:
|
|
261 |
return self
|
262 |
|
263 |
|
264 |
-
def worker_init_fn(worker_id):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
state = np.random.get_state()
|
266 |
assert isinstance(state, tuple)
|
267 |
assert isinstance(state[1], np.ndarray)
|
@@ -274,6 +325,7 @@ def worker_init_fn(worker_id):
|
|
274 |
|
275 |
|
276 |
class ViTLightningModule(L.LightningModule):
|
|
|
277 |
def __init__(self, debug: bool) -> None:
|
278 |
super().__init__()
|
279 |
|
@@ -443,6 +495,7 @@ class ViTLightningModule(L.LightningModule):
|
|
443 |
return loss
|
444 |
|
445 |
def _dump_train_images(self) -> None:
|
|
|
446 |
img_batch, label_batch = next(iter(self._train_dataloader))
|
447 |
for i_img, (img, label) in enumerate(zip(img_batch, label_batch)):
|
448 |
img_np = img.cpu().numpy()
|
@@ -494,18 +547,19 @@ class ViTLightningModule(L.LightningModule):
|
|
494 |
|
495 |
|
496 |
def main():
|
|
|
497 |
|
498 |
parser = ArgumentParser(description='KAUST-SDAIA Diabetic Retinopathy')
|
499 |
parser.add_argument('--tag', action='store', type=str,
|
500 |
help='Extra suffix to put on the artefact dir name')
|
501 |
-
parser.add_argument('--debug', action='store_true'
|
|
|
502 |
parser.add_argument('--convert-checkpoint', action='store', type=str,
|
503 |
help='Convert a checkpoint from training to pickle-independent '
|
504 |
'predictor-compatible directory')
|
505 |
|
506 |
args = parser.parse_args()
|
507 |
|
508 |
-
|
509 |
torch.set_float32_matmul_precision('high') # for V100/A100
|
510 |
|
511 |
if args.convert_checkpoint is not None:
|
|
|
49 |
|
50 |
|
51 |
class RetinopathyDataset(data.Dataset[DataRecord]):
|
52 |
+
""" A class to access the pre-downloaded Diabetic Retinopathy dataset. """
|
53 |
+
|
54 |
def __init__(self, data_path: str) -> None:
|
55 |
+
""" Constructor.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
data_path (str): path to the dataset, ex: "retinopathy_data"
|
59 |
+
containing "trainLabels.csv" and "train/".
|
60 |
+
"""
|
61 |
super().__init__()
|
62 |
|
63 |
self.data_path = data_path
|
|
|
96 |
return img_path
|
97 |
|
98 |
|
99 |
+
""" Purpose of a split: training or validation. """
|
100 |
class Purpose(Enum):
|
101 |
Train = 0
|
102 |
Val = 1
|
103 |
|
104 |
+
""" Augmentation transformations for an image and a label. """
|
105 |
FeatureAndTargetTransforms = Tuple[Callable[..., torch.Tensor],
|
106 |
Callable[..., torch.Tensor]]
|
107 |
|
108 |
+
""" Feature (image) and target (label) tensors. """
|
109 |
TensorRecord = Tuple[torch.Tensor, torch.Tensor]
|
110 |
|
|
|
|
|
|
|
111 |
|
112 |
class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
|
113 |
+
""" Split is a class that keeps a view on a part of a dataset.
|
114 |
+
Split is used to hold the information about which samples go to training
|
115 |
+
and which to validation without a need to put these groups of files into
|
116 |
+
separate folders.
|
117 |
+
"""
|
118 |
def __init__(self, dataset: RetinopathyDataset,
|
119 |
indices: np.ndarray,
|
120 |
purpose: Purpose,
|
|
|
123 |
stratify_classes: bool = False,
|
124 |
use_log_frequencies: bool = False,
|
125 |
):
|
126 |
+
""" Constructor.
|
127 |
+
|
128 |
+
Args:
|
129 |
+
dataset (RetinopathyDataset): The dataset on which the Split "views".
|
130 |
+
indices (np.ndarray): Externally provided indices of samples that
|
131 |
+
are "viewed" on.
|
132 |
+
purpose (Purpose): Either train or val, to be able to replicate
|
133 |
+
the data for train split for efficient worker utilization.
|
134 |
+
transforms (FeatureAndTargetTransforms): Functors of feature and
|
135 |
+
target transforms.
|
136 |
+
oversample_factor (int, optional): Expand the training dataset by
|
137 |
+
replication to avoid dataloader stalls on epoch ends. Defaults to 1.
|
138 |
+
stratify_classes (bool, optional): Whether to apply stratified sampling.
|
139 |
+
Defaults to False.
|
140 |
+
use_log_frequencies (bool, optional): If stratify_classes=True,
|
141 |
+
whether to use logarithmic sampling strategy. If False, apply
|
142 |
+
regular even sampling. Defaults to False.
|
143 |
+
"""
|
144 |
self.dataset = dataset
|
145 |
self.indices = indices
|
146 |
self.purpose = purpose
|
|
|
153 |
self.per_class_indices: Optional[Dict[int, np.ndarray]] = None
|
154 |
self.frequencies: Optional[Dict[int, float]] = None
|
155 |
if self.stratify_classes:
|
156 |
+
self._bucketize_indices()
|
157 |
if self.use_log_frequencies:
|
158 |
+
self._calc_frequencies()
|
159 |
|
160 |
+
def _calc_frequencies(self):
|
161 |
assert self.per_class_indices is not None
|
162 |
counts_dict = {lbl: len(arr) for lbl, arr in self.per_class_indices.items()}
|
163 |
counts = np.array(list(counts_dict.values()))
|
164 |
+
counts_nrm = self._normalize(counts)
|
165 |
temperature = 50.0 # > 1 to even-out frequencies
|
166 |
+
freqs = self._normalize(np.log1p(counts_nrm * temperature))
|
167 |
self.frequencies = {k: freq.item() for k, freq
|
168 |
in zip(self.per_class_indices.keys(), freqs)}
|
169 |
print(self.frequencies)
|
170 |
|
171 |
+
@staticmethod
|
172 |
+
def _normalize(arr: np.ndarray) -> np.ndarray:
|
173 |
+
return arr / np.sum(arr)
|
174 |
+
|
175 |
+
def _bucketize_indices(self):
|
176 |
buckets = defaultdict(list)
|
177 |
for index in self.indices:
|
178 |
label = self.dataset.get_label_at(index)
|
|
|
224 |
seed: int = 54,
|
225 |
) -> Tuple['Split', 'Split']:
|
226 |
|
227 |
+
""" Prepare train and val splits deterministically.
|
228 |
+
|
229 |
+
Returns:
|
230 |
+
Tuple[Split, Split]:
|
231 |
+
- Train split
|
232 |
+
- Val split
|
233 |
+
"""
|
234 |
+
|
235 |
prng = RandomState(seed)
|
236 |
|
237 |
num_train = int(len(all_data) * train_fraction)
|
|
|
245 |
return train_data, val_data
|
246 |
|
247 |
|
248 |
+
def print_data_stats(dataset: Union[Iterable[DataRecord], DataLoader],
|
249 |
+
split_name: str) -> None:
|
250 |
labels = []
|
251 |
for _, label in dataset:
|
252 |
if isinstance(label, torch.Tensor):
|
|
|
303 |
return self
|
304 |
|
305 |
|
306 |
+
def worker_init_fn(worker_id: int) -> None:
|
307 |
+
""" Initialize workers in a way that they draw different
|
308 |
+
random samples and do not repeat identical pseudorandom
|
309 |
+
sequences of each other, which may be the case with Fork
|
310 |
+
multiprocessing.
|
311 |
+
|
312 |
+
Args:
|
313 |
+
worker_id (int): id of a preprocessing worker process launched
|
314 |
+
by one DDP training process.
|
315 |
+
"""
|
316 |
state = np.random.get_state()
|
317 |
assert isinstance(state, tuple)
|
318 |
assert isinstance(state[1], np.ndarray)
|
|
|
325 |
|
326 |
|
327 |
class ViTLightningModule(L.LightningModule):
|
328 |
+
""" Lightning Module that implements neural network training hooks. """
|
329 |
def __init__(self, debug: bool) -> None:
|
330 |
super().__init__()
|
331 |
|
|
|
495 |
return loss
|
496 |
|
497 |
def _dump_train_images(self) -> None:
|
498 |
+
""" Save augmented images to disk for inspection. """
|
499 |
img_batch, label_batch = next(iter(self._train_dataloader))
|
500 |
for i_img, (img, label) in enumerate(zip(img_batch, label_batch)):
|
501 |
img_np = img.cpu().numpy()
|
|
|
547 |
|
548 |
|
549 |
def main():
|
550 |
+
""" Neural network trainer entry point. """
|
551 |
|
552 |
parser = ArgumentParser(description='KAUST-SDAIA Diabetic Retinopathy')
|
553 |
parser.add_argument('--tag', action='store', type=str,
|
554 |
help='Extra suffix to put on the artefact dir name')
|
555 |
+
parser.add_argument('--debug', action='store_true',
|
556 |
+
help="Dummy training cycle for testing purposes")
|
557 |
parser.add_argument('--convert-checkpoint', action='store', type=str,
|
558 |
help='Convert a checkpoint from training to pickle-independent '
|
559 |
'predictor-compatible directory')
|
560 |
|
561 |
args = parser.parse_args()
|
562 |
|
|
|
563 |
torch.set_float32_matmul_precision('high') # for V100/A100
|
564 |
|
565 |
if args.convert_checkpoint is not None:
|