Dmitrii Khizbullin committed
Commit d2c9a97 · unverified · 2 Parent(s): 3a95ae1 6d2185b

Setup instructions, docstrings

Files changed (6):
  1. README.md +45 -1
  2. app.py +35 -3
  3. environment.yml +247 -18
  4. labelmap.py +2 -0
  5. requirements.txt +74 -0
  6. train.py +69 -15
README.md CHANGED
@@ -1 +1,45 @@
- # diabetic-retinopathy-detection
+ # Diabetic Retinopathy Detection with AI
+
+ ## Setup
+
+ ### Gradio app environment
+
+ Install from the pip requirements file:
+
+ ```bash
+ conda create -n retinopathy_app python=3.10
+ conda activate retinopathy_app
+ pip install -r requirements.txt
+ python app.py
+ ```
+
+ Install manually:
+
+ ```bash
+ pip install torch --index-url https://download.pytorch.org/whl/cpu
+ pip install gradio
+ pip install transformers
+ ```
+
+ ### Training environment
+
+ Create the conda environment from the YAML file:
+ ```bash
+ mamba env create -n retinopathy_train -f environment.yml
+ ```
+
+ Download the data from [Kaggle](https://www.kaggle.com/competitions/diabetic-retinopathy-detection/data) or use the Kaggle API:
+
+ ```bash
+ pip install kaggle
+ kaggle competitions download -c diabetic-retinopathy-detection
+ mkdir retinopathy_data/
+ unzip diabetic-retinopathy-detection.zip -d retinopathy_data/
+ ```
+
+ Launch training:
+ ```bash
+ conda activate retinopathy_train
+ python train.py
+ ```
+ The trained model will be saved to `lightning_logs/`.
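
After unzipping, `train.py` expects `retinopathy_data/` to contain `trainLabels.csv` and a `train/` folder of JPEG scans. A quick sanity check of the layout (a sketch; the exact file count depends on the Kaggle download):

```python
from pathlib import Path

root = Path("retinopathy_data")
assert (root / "trainLabels.csv").is_file(), "labels CSV missing"
num_images = len(list((root / "train").glob("*.jpeg")))
print(f"Found {num_images} training images")
```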
app.py CHANGED
@@ -2,7 +2,7 @@ import os
  import gradio as gr
  import numpy as np
  import torch
  from typing import Tuple, Optional, Dict, List
  import glob
  from collections import defaultdict

@@ -13,7 +13,10 @@ from labelmap import DR_LABELMAP


  class App:
+     """ Demonstration of the Diabetic Retinopathy model as a Gradio app. """
+
      def __init__(self) -> None:
+         """ Constructor. """

          ckpt_name = "2023-12-24_20-02-18_30345221_V100_x4_resnet34/"

@@ -41,7 +44,7 @@ class App:
              output = gr.Label(num_top_classes=len(DR_LABELMAP),
                                label="Retinopathy level prediction")
          with gr.Column(scale=4):
-             gr.Markdown("![](https://media.githubusercontent.com/media/Obs01ete/retinopathy/master/media/logo1.png)")
+             gr.Markdown("![](https://media.githubusercontent.com/media/SDAIA-KAUST-AI/diabetic-retinopathy-detection/main/media/logo1.png)")
      with gr.Row():
          with gr.Column(scale=9, min_width=100):
              image = gr.Image(label="Retina scan")
@@ -66,9 +69,19 @@ class App:
          self.ui = ui

      def launch(self) -> None:
+         """ Launch the application, blocking. """
          self.ui.queue().launch(share=True)

-     def predict(self, image: Optional[np.ndarray]):
+     def predict(self, image: Optional[np.ndarray]) -> Dict[str, float]:
+         """ Gradio callback for processing an image.
+
+         Args:
+             image (Optional[np.ndarray]): Provided image.
+
+         Returns:
+             Dict[str, float]: Label-compatible dict.
+         """
+
          if image is None:
              return dict()
          cls_name, prob, probs = self._infer(image)
@@ -79,6 +92,19 @@ class App:
          return probs_dict

      def _infer(self, image_chw: np.ndarray) -> Tuple[str, float, np.ndarray]:
+         """ Low-level method to perform neural network inference.
+
+         Args:
+             image_chw (np.ndarray): Provided image.
+
+         Returns:
+             Tuple[str, float, np.ndarray]:
+             - Most probable class name.
+             - Probability of the most probable class.
+             - Probabilities of all classes in the order they
+               are listed in the label map.
+         """
+
          assert isinstance(self.model, ResNetForImageClassification)

          inputs = self.image_processor(image_chw, return_tensors="pt")
@@ -98,6 +124,11 @@ class App:

      @staticmethod
      def _load_example_lists() -> Dict[int, List[str]]:
+         """ Load example retina images from disk.
+
+         Returns:
+             Dict[int, List[str]]: Dictionary of cls_id -> list of image paths.
+         """

          example_flat_list = glob.glob("demo_data/train/**/*.jpeg")

@@ -115,6 +146,7 @@ class App:


  def main():
+     """ App entry point. """
      app = App()
      app.launch()
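
For reference, a standalone sketch of the `_infer` path above, assuming the checkpoint directory was exported in the `transformers` format and that the image processor is loaded the same way as the model (illustrative wiring, not the app's exact code):

```python
import numpy as np
import torch
from transformers import AutoImageProcessor, ResNetForImageClassification

from labelmap import DR_LABELMAP

ckpt_dir = "2023-12-24_20-02-18_30345221_V100_x4_resnet34/"
image_processor = AutoImageProcessor.from_pretrained(ckpt_dir)
model = ResNetForImageClassification.from_pretrained(ckpt_dir)

image = np.zeros((448, 448, 3), dtype=np.uint8)  # stand-in for a retina scan
inputs = image_processor(image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
probs = torch.softmax(logits, dim=-1)[0].numpy()
cls_id = int(probs.argmax())
print(DR_LABELMAP[cls_id], float(probs[cls_id]))
```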
 
environment.yml CHANGED
@@ -1,22 +1,251 @@
- name: diabetic-retinopathy-detection
-
+ name: retinopathy
  channels:
- - pytorch
- - nvidia
+ - anaconda
  - conda-forge
  - defaults
-
  dependencies:
- - dask
- - gh
- - git
- - kaggle
- - jupyterlab
- - jupyterlab-nvdashboard
- - lightning
- - nbgitpuller
- - pytorch
- - pytorch-cuda=11.8
- - scikit-learn
- - torchvision
-
+ - _libgcc_mutex=0.1=main
+ - _openmp_mutex=5.1=1_gnu
+ - aiofiles=22.1.0=py310h06a4308_0
+ - aiosqlite=0.18.0=py310h06a4308_0
+ - argon2-cffi=21.3.0=pyhd3eb1b0_0
+ - argon2-cffi-bindings=21.2.0=py310h7f8727e_0
+ - asttokens=2.0.5=pyhd3eb1b0_0
+ - attrs=23.1.0=py310h06a4308_0
+ - babel=2.11.0=py310h06a4308_0
+ - backcall=0.2.0=pyhd3eb1b0_0
+ - beautifulsoup4=4.12.2=py310h06a4308_0
+ - bleach=4.1.0=pyhd3eb1b0_0
+ - brotli-python=1.0.9=py310h6a678d5_7
+ - bzip2=1.0.8=h7b6447c_0
+ - ca-certificates=2023.08.22=h06a4308_0
+ - certifi=2023.11.17=py310h06a4308_0
+ - cffi=1.16.0=py310h5eee18b_0
+ - comm=0.1.2=py310h06a4308_0
+ - cryptography=41.0.3=py310hdda0065_0
+ - debugpy=1.6.7=py310h6a678d5_0
+ - decorator=5.1.1=pyhd3eb1b0_0
+ - defusedxml=0.7.1=pyhd3eb1b0_0
+ - executing=0.8.3=pyhd3eb1b0_0
+ - ipykernel=6.25.0=py310h2f386ee_0
+ - ipython=8.15.0=py310h06a4308_0
+ - ipython_genutils=0.2.0=pyhd3eb1b0_1
+ - jedi=0.18.1=py310h06a4308_1
+ - jinja2=3.1.2=py310h06a4308_0
+ - json5=0.9.6=pyhd3eb1b0_0
+ - jsonschema=4.19.2=py310h06a4308_0
+ - jsonschema-specifications=2023.7.1=py310h06a4308_0
+ - jupyter_client=8.6.0=py310h06a4308_0
+ - jupyter_core=5.5.0=py310h06a4308_0
+ - jupyter_events=0.8.0=py310h06a4308_0
+ - jupyter_server=2.10.0=py310h06a4308_0
+ - jupyter_server_fileid=0.9.0=py310h06a4308_0
+ - jupyter_server_terminals=0.4.4=py310h06a4308_1
+ - jupyter_server_ydoc=0.8.0=py310h06a4308_1
+ - jupyter_ydoc=0.2.4=py310h06a4308_0
+ - jupyterlab=3.6.3=py310h06a4308_0
+ - jupyterlab_pygments=0.2.2=py310h06a4308_0
+ - jupyterlab_server=2.25.1=py310h06a4308_0
+ - ld_impl_linux-64=2.38=h1181459_1
+ - libffi=3.4.4=h6a678d5_0
+ - libgcc-ng=11.2.0=h1234567_1
+ - libgomp=11.2.0=h1234567_1
+ - libsodium=1.0.18=h7b6447c_0
+ - libstdcxx-ng=11.2.0=h1234567_1
+ - libuuid=1.41.5=h5eee18b_0
+ - matplotlib-inline=0.1.6=py310h06a4308_0
+ - mistune=2.0.4=py310h06a4308_0
+ - nbclassic=1.0.0=py310h06a4308_0
+ - nbclient=0.8.0=py310h06a4308_0
+ - nbconvert=7.10.0=py310h06a4308_0
+ - nbformat=5.9.2=py310h06a4308_0
+ - ncurses=6.4=h6a678d5_0
+ - nest-asyncio=1.5.6=py310h06a4308_0
+ - notebook=6.5.4=py310h06a4308_0
+ - notebook-shim=0.2.3=py310h06a4308_0
+ - openssl=3.0.12=h7f8727e_0
+ - overrides=7.4.0=py310h06a4308_0
+ - pandocfilters=1.5.0=pyhd3eb1b0_0
+ - parso=0.8.3=pyhd3eb1b0_0
+ - pexpect=4.8.0=pyhd3eb1b0_3
+ - pickleshare=0.7.5=pyhd3eb1b0_1003
+ - platformdirs=3.10.0=py310h06a4308_0
+ - prometheus_client=0.14.1=py310h06a4308_0
+ - prompt-toolkit=3.0.36=py310h06a4308_0
+ - ptyprocess=0.7.0=pyhd3eb1b0_2
+ - pure_eval=0.2.2=pyhd3eb1b0_0
+ - pycparser=2.21=pyhd3eb1b0_0
+ - pyopenssl=23.2.0=py310h06a4308_0
+ - pysocks=1.7.1=py310h06a4308_0
+ - python=3.10.13=h955ad1f_0
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
+ - python-fastjsonschema=2.16.2=py310h06a4308_0
+ - python-json-logger=2.0.7=py310h06a4308_0
+ - pytz=2023.3.post1=py310h06a4308_0
+ - pyyaml=6.0.1=py310h5eee18b_0
+ - pyzmq=25.1.0=py310h6a678d5_0
+ - readline=8.2=h5eee18b_0
+ - referencing=0.30.2=py310h06a4308_0
+ - requests=2.31.0=py310h06a4308_0
+ - rfc3339-validator=0.1.4=py310h06a4308_0
+ - rfc3986-validator=0.1.1=py310h06a4308_0
+ - rpds-py=0.10.6=py310hb02cf49_0
+ - send2trash=1.8.2=py310h06a4308_0
+ - setuptools=68.0.0=py310h06a4308_0
+ - six=1.16.0=pyhd3eb1b0_1
+ - soupsieve=2.5=py310h06a4308_0
+ - sqlite=3.41.2=h5eee18b_0
+ - stack_data=0.2.0=pyhd3eb1b0_0
+ - terminado=0.17.1=py310h06a4308_0
+ - tinycss2=1.2.1=py310h06a4308_0
+ - tk=8.6.12=h1ccaba5_0
+ - tomli=2.0.1=py310h06a4308_0
+ - tornado=6.3.3=py310h5eee18b_0
+ - webencodings=0.5.1=py310h06a4308_1
+ - wheel=0.41.2=py310h06a4308_0
+ - y-py=0.5.9=py310h52d8a92_0
+ - yaml=0.2.5=h7b6447c_0
+ - ypy-websocket=0.8.2=py310h06a4308_0
+ - zeromq=4.3.4=h2531618_0
+ - zlib=1.2.13=h5eee18b_0
+ - pip:
+   - absl-py==2.0.0
+   - aiobotocore==2.8.0
+   - aiohttp==3.9.1
+   - aioitertools==0.11.0
+   - aiosignal==1.3.1
+   - altair==5.2.0
+   - annotated-types==0.6.0
+   - antlr4-python3-runtime==4.9.3
+   - anyio==3.7.1
+   - arrow==1.3.0
+   - async-timeout==4.0.3
+   - backoff==2.2.1
+   - bitsandbytes==0.41.3
+   - blessed==1.20.0
+   - boto3==1.33.1
+   - botocore==1.33.1
+   - cachetools==5.3.2
+   - chardet==5.2.0
+   - charset-normalizer==3.3.2
+   - click==8.1.7
+   - colorama==0.4.6
+   - contourpy==1.2.0
+   - croniter==1.4.1
+   - cycler==0.12.1
+   - dateutils==0.6.12
+   - deepdiff==6.7.1
+   - docker==6.1.3
+   - docstring-parser==0.15
+   - exceptiongroup==1.2.0
+   - fastapi==0.104.1
+   - ffmpy==0.3.1
+   - filelock==3.13.1
+   - fonttools==4.46.0
+   - frozenlist==1.4.0
+   - fsspec==2023.12.1
+   - google-auth==2.25.1
+   - google-auth-oauthlib==1.1.0
+   - gradio==4.12.0
+   - gradio-client==0.8.0
+   - grpcio==1.59.3
+   - h11==0.14.0
+   - httpcore==1.0.2
+   - httpx==0.26.0
+   - huggingface-hub==0.19.4
+   - hydra-core==1.3.2
+   - idna==3.6
+   - importlib-resources==6.1.1
+   - inquirer==3.1.4
+   - itsdangerous==2.1.2
+   - jmespath==1.0.1
+   - jsonargparse==4.27.1
+   - kiwisolver==1.4.5
+   - lightning==2.1.2
+   - lightning-api-access==0.0.5
+   - lightning-cloud==0.5.52
+   - lightning-fabric==2.1.2
+   - lightning-utilities==0.10.0
+   - line-profiler==4.1.2
+   - markdown==3.5.1
+   - markdown-it-py==3.0.0
+   - markupsafe==2.1.3
+   - matplotlib==3.8.2
+   - mdurl==0.1.2
+   - mpmath==1.3.0
+   - multidict==6.0.4
+   - networkx==3.2.1
+   - numpy==1.26.2
+   - nvidia-cublas-cu12==12.1.3.1
+   - nvidia-cuda-cupti-cu12==12.1.105
+   - nvidia-cuda-nvrtc-cu12==12.1.105
+   - nvidia-cuda-runtime-cu12==12.1.105
+   - nvidia-cudnn-cu12==8.9.2.26
+   - nvidia-cufft-cu12==11.0.2.54
+   - nvidia-curand-cu12==10.3.2.106
+   - nvidia-cusolver-cu12==11.4.5.107
+   - nvidia-cusparse-cu12==12.1.0.106
+   - nvidia-nccl-cu12==2.18.1
+   - nvidia-nvjitlink-cu12==12.3.101
+   - nvidia-nvtx-cu12==12.1.105
+   - oauthlib==3.2.2
+   - omegaconf==2.3.0
+   - ordered-set==4.1.0
+   - orjson==3.9.10
+   - packaging==23.2
+   - pandas==2.1.3
+   - pillow==10.1.0
+   - protobuf==4.23.4
+   - psutil==5.9.6
+   - pyasn1==0.5.1
+   - pyasn1-modules==0.3.0
+   - pydantic==2.5.2
+   - pydantic-core==2.14.5
+   - pydub==0.25.1
+   - pygments==2.17.2
+   - pyjwt==2.8.0
+   - pyparsing==3.1.1
+   - python-editor==1.0.4
+   - python-multipart==0.0.6
+   - pytorch-lightning==2.1.2
+   - readchar==4.0.5
+   - redis==5.0.1
+   - regex==2023.10.3
+   - requests-oauthlib==1.3.1
+   - rich==13.7.0
+   - rsa==4.9
+   - s3fs==2023.12.1
+   - s3transfer==0.8.0
+   - safetensors==0.4.1
+   - semantic-version==2.10.0
+   - shellingham==1.5.4
+   - sniffio==1.3.0
+   - starlette==0.27.0
+   - starsessions==1.3.0
+   - sympy==1.12
+   - tensorboard==2.15.1
+   - tensorboard-data-server==0.7.2
+   - tensorboardx==2.6.2.2
+   - tokenizers==0.15.0
+   - tomlkit==0.12.0
+   - toolz==0.12.0
+   - torch==2.1.1
+   - torchmetrics==1.2.1
+   - torchvision==0.16.1
+   - tqdm==4.66.1
+   - traitlets==5.14.0
+   - transformers==4.35.2
+   - triton==2.1.0
+   - typer==0.9.0
+   - types-python-dateutil==2.8.19.14
+   - typeshed-client==2.4.0
+   - typing-extensions==4.9.0
+   - tzdata==2023.3
+   - urllib3==2.0.7
+   - uvicorn==0.24.0.post1
+   - wcwidth==0.2.12
+   - websocket-client==1.7.0
+   - websockets==11.0.3
+   - werkzeug==3.0.1
+   - wrapt==1.16.0
+   - yarl==1.9.4
labelmap.py CHANGED
@@ -1,3 +1,5 @@
+ """ Mapping of class IDs to labels. """
+
  DR_LABELMAP = {
      0: 'No DR',
      1: 'Mild',
requirements.txt ADDED
@@ -0,0 +1,74 @@
+ --extra-index-url https://download.pytorch.org/whl/cpu
+ aiofiles==23.2.1
+ altair==5.2.0
+ annotated-types==0.6.0
+ anyio==4.2.0
+ attrs==23.2.0
+ certifi==2023.11.17
+ charset-normalizer==3.3.2
+ click==8.1.7
+ colorama==0.4.6
+ contourpy==1.2.0
+ cycler==0.12.1
+ exceptiongroup==1.2.0
+ fastapi==0.108.0
+ ffmpy==0.3.1
+ filelock==3.13.1
+ fonttools==4.47.0
+ fsspec==2023.12.2
+ gradio==4.13.0
+ gradio_client==0.8.0
+ h11==0.14.0
+ httpcore==1.0.2
+ httpx==0.26.0
+ huggingface-hub==0.20.2
+ idna==3.6
+ importlib-resources==6.1.1
+ Jinja2==3.1.2
+ jsonschema==4.20.0
+ jsonschema-specifications==2023.12.1
+ kiwisolver==1.4.5
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.3
+ matplotlib==3.8.2
+ mdurl==0.1.2
+ mpmath==1.3.0
+ networkx==3.2.1
+ numpy==1.26.3
+ orjson==3.9.10
+ packaging==23.2
+ pandas==2.1.4
+ pillow==10.2.0
+ pydantic==2.5.3
+ pydantic_core==2.14.6
+ pydub==0.25.1
+ Pygments==2.17.2
+ pyparsing==3.1.1
+ python-dateutil==2.8.2
+ python-multipart==0.0.6
+ pytz==2023.3.post1
+ PyYAML==6.0.1
+ referencing==0.32.1
+ regex==2023.12.25
+ requests==2.31.0
+ rich==13.7.0
+ rpds-py==0.16.2
+ safetensors==0.4.1
+ semantic-version==2.10.0
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.0
+ starlette==0.32.0.post1
+ sympy==1.12
+ tokenizers==0.15.0
+ tomlkit==0.12.0
+ toolz==0.12.0
+ torch==2.1.2+cpu
+ tqdm==4.66.1
+ transformers==4.36.2
+ typer==0.9.0
+ typing_extensions==4.9.0
+ tzdata==2023.4
+ urllib3==2.1.0
+ uvicorn==0.25.0
+ websockets==11.0.3
train.py CHANGED
@@ -49,7 +49,15 @@ DataRecord = Tuple[Image.Image, int]


  class RetinopathyDataset(data.Dataset[DataRecord]):
+     """ A class to access the pre-downloaded Diabetic Retinopathy dataset. """
+
      def __init__(self, data_path: str) -> None:
+         """ Constructor.
+
+         Args:
+             data_path (str): Path to the dataset, e.g. "retinopathy_data",
+                 containing "trainLabels.csv" and "train/".
+         """
          super().__init__()

          self.data_path = data_path
@@ -88,21 +96,25 @@ class RetinopathyDataset(data.Dataset[DataRecord]):
          return img_path


+ # Purpose of a split: training or validation.
  class Purpose(Enum):
      Train = 0
      Val = 1

-
+ # Augmentation transformations for an image and a label.
  FeatureAndTargetTransforms = Tuple[Callable[..., torch.Tensor],
                                     Callable[..., torch.Tensor]]

+ # Feature (image) and target (label) tensors.
  TensorRecord = Tuple[torch.Tensor, torch.Tensor]

- def normalize(arr: np.ndarray) -> np.ndarray:
-     return arr / np.sum(arr)
-

  class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
+     """ Split keeps a view on a part of a dataset. It holds the information
+     about which samples go to training and which to validation, without the
+     need to put these groups of files into separate folders.
+     """
      def __init__(self, dataset: RetinopathyDataset,
                   indices: np.ndarray,
                   purpose: Purpose,
@@ -111,7 +123,24 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
                   stratify_classes: bool = False,
                   use_log_frequencies: bool = False,
                   ):
-
+         """ Constructor.
+
+         Args:
+             dataset (RetinopathyDataset): The dataset on which the Split "views".
+             indices (np.ndarray): Externally provided indices of the samples
+                 that are "viewed" on.
+             purpose (Purpose): Either train or val, to be able to replicate
+                 the data of the train split for efficient worker utilization.
+             transforms (FeatureAndTargetTransforms): Functors of feature and
+                 target transforms.
+             oversample_factor (int, optional): Expand the training dataset by
+                 replication to avoid dataloader stalls on epoch ends. Defaults to 1.
+             stratify_classes (bool, optional): Whether to apply stratified
+                 sampling. Defaults to False.
+             use_log_frequencies (bool, optional): If stratify_classes=True,
+                 whether to use the logarithmic sampling strategy. If False,
+                 apply regular even sampling. Defaults to False.
+         """
          self.dataset = dataset
          self.indices = indices
          self.purpose = purpose
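
A hypothetical construction of the two splits described above (`train_tf`, `val_tf`, `target_tf` and the index arrays stand in for whatever train.py actually passes):

```python
import numpy as np

dataset = RetinopathyDataset("retinopathy_data")
indices = np.arange(len(dataset))

train_split = Split(dataset, indices[:30_000], Purpose.Train,
                    (train_tf, target_tf),
                    oversample_factor=2,  # replicate to keep workers busy
                    stratify_classes=True,
                    use_log_frequencies=True)
val_split = Split(dataset, indices[30_000:], Purpose.Val,
                  (val_tf, target_tf))
```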
 
@@ -124,22 +153,26 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
          self.per_class_indices: Optional[Dict[int, np.ndarray]] = None
          self.frequencies: Optional[Dict[int, float]] = None
          if self.stratify_classes:
-             self.bucketize_indices()
+             self._bucketize_indices()
              if self.use_log_frequencies:
-                 self.calc_frequencies()
+                 self._calc_frequencies()

-     def calc_frequencies(self):
+     def _calc_frequencies(self):
          assert self.per_class_indices is not None
          counts_dict = {lbl: len(arr) for lbl, arr in self.per_class_indices.items()}
          counts = np.array(list(counts_dict.values()))
-         counts_nrm = normalize(counts)
+         counts_nrm = self._normalize(counts)
          temperature = 50.0  # > 1 to even out frequencies
-         freqs = normalize(np.log1p(counts_nrm * temperature))
+         freqs = self._normalize(np.log1p(counts_nrm * temperature))
          self.frequencies = {k: freq.item() for k, freq
                              in zip(self.per_class_indices.keys(), freqs)}
          print(self.frequencies)

-     def bucketize_indices(self):
+     @staticmethod
+     def _normalize(arr: np.ndarray) -> np.ndarray:
+         return arr / np.sum(arr)
+
+     def _bucketize_indices(self):
          buckets = defaultdict(list)
          for index in self.indices:
              label = self.dataset.get_label_at(index)
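
To see why the tempered logarithm helps: for a class histogram as skewed as this dataset's, `np.log1p(counts_nrm * 50.0)` compresses the dominant class and lifts the rare ones. A worked example with made-up counts:

```python
import numpy as np

def normalize(arr: np.ndarray) -> np.ndarray:
    return arr / np.sum(arr)

counts = np.array([25000, 2400, 5300, 900, 700])  # hypothetical per-class counts
counts_nrm = normalize(counts)                    # ~[0.73, 0.07, 0.15, 0.03, 0.02]
freqs = normalize(np.log1p(counts_nrm * 50.0))    # ~[0.41, 0.17, 0.24, 0.10, 0.08]
# Sampling by freqs is far closer to uniform than sampling by raw counts.
```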
 
@@ -191,6 +224,14 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
                          seed: int = 54,
                          ) -> Tuple['Split', 'Split']:

+         """ Prepare train and val splits deterministically.
+
+         Returns:
+             Tuple[Split, Split]:
+             - Train split
+             - Val split
+         """
+
          prng = RandomState(seed)

          num_train = int(len(all_data) * train_fraction)
@@ -204,7 +245,8 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
          return train_data, val_data


- def print_data_stats(dataset: Union[Iterable[DataRecord], DataLoader], split_name: str) -> None:
+ def print_data_stats(dataset: Union[Iterable[DataRecord], DataLoader],
+                      split_name: str) -> None:
      labels = []
      for _, label in dataset:
          if isinstance(label, torch.Tensor):
@@ -261,7 +303,16 @@ class Metrics:
          return self


- def worker_init_fn(worker_id):
+ def worker_init_fn(worker_id: int) -> None:
+     """ Initialize workers in a way that they draw different random samples
+     and do not repeat identical pseudorandom sequences of each other, which
+     may be the case with fork multiprocessing.
+
+     Args:
+         worker_id (int): ID of a preprocessing worker process launched
+             by one DDP training process.
+     """
      state = np.random.get_state()
      assert isinstance(state, tuple)
      assert isinstance(state[1], np.ndarray)
@@ -274,6 +325,7 @@ def worker_init_fn(worker_id):


  class ViTLightningModule(L.LightningModule):
+     """ Lightning Module that implements the neural network training hooks. """
      def __init__(self, debug: bool) -> None:
          super().__init__()

@@ -443,6 +495,7 @@ class ViTLightningModule(L.LightningModule):
          return loss

      def _dump_train_images(self) -> None:
+         """ Save augmented images to disk for inspection. """
          img_batch, label_batch = next(iter(self._train_dataloader))
          for i_img, (img, label) in enumerate(zip(img_batch, label_batch)):
              img_np = img.cpu().numpy()
@@ -494,18 +547,19 @@ class ViTLightningModule(L.LightningModule):


  def main():
+     """ Neural network trainer entry point. """

      parser = ArgumentParser(description='KAUST-SDAIA Diabetic Retinopathy')
      parser.add_argument('--tag', action='store', type=str,
                          help='Extra suffix to put on the artefact dir name')
-     parser.add_argument('--debug', action='store_true')
+     parser.add_argument('--debug', action='store_true',
+                         help='Dummy training cycle for testing purposes')
      parser.add_argument('--convert-checkpoint', action='store', type=str,
                          help='Convert a checkpoint from training to pickle-independent '
                               'predictor-compatible directory')

      args = parser.parse_args()

-
      torch.set_float32_matmul_precision('high')  # for V100/A100

      if args.convert_checkpoint is not None:
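
The intent of `worker_init_fn` above: with fork-based multiprocessing every dataloader worker inherits the same NumPy RNG state, so without reseeding all workers would produce identical augmentation streams. A minimal sketch of the reseeding idea (the actual function derives the seed from the inherited state, as shown in the hunk):

```python
import numpy as np
from torch.utils.data import DataLoader

def worker_init_fn(worker_id: int) -> None:
    # Mix the inherited seed with the worker id so each worker
    # draws a distinct pseudorandom sequence.
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed((base_seed + worker_id) % 2**32)

# loader = DataLoader(train_split, batch_size=32, num_workers=4,
#                     worker_init_fn=worker_init_fn)
```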