outofray committed on
Commit
2f3b6c4
·
1 Parent(s): a446599

update_repo

Files changed (15)
  1. README.md +143 -3
  2. add_clearml_yolov5.patch +215 -0
  3. dataset.py +526 -0
  4. demo.bat +12 -0
  5. demo.py +781 -0
  6. demo_headless.sh +27 -0
  7. eval.py +397 -0
  8. plots.py +303 -0
  9. predict.py +470 -0
  10. requirements.txt +26 -0
  11. roi.py +34 -0
  12. try_chart.ipynb +0 -0
  13. usgfw2wrapper.dll +0 -0
  14. weights/.keep +1 -0
  15. weights/yolov5s-v2 +1 -0
README.md CHANGED
@@ -1,3 +1,143 @@
1
- ---
2
- license: mit
3
- ---
1
+ <!-- #region -->
2
+ # Automation of Aorta Measurement in Ultrasound Images
3
+
4
+ ## Env setup
5
+
6
+ Suggested hardware:
7
+
8
+ - GPU: 1x NVIDIA RTX 3090 or higher (model training with PyTorch)
9
+ - CPU: 11th Gen Intel(R) Core(TM) i9-11900KF @ 3.50GHz, or higher (model inference using OpenVINO)
10
+
11
+ Software stack:
12
+
13
+ - OS: Ubuntu 20.04 LTS
14
+ - Python: 3.8+
15
+ - Python Env: conda
16
+
17
+ ```shell
18
+ conda create -n aorta python=3.8 -y
19
+ conda activate aorta
20
+ pip install -r requirements.txt
21
+ ```
22
+
23
+ ## Dataset
24
+
25
+ Steps to prepare the dataset:
26
+
27
+ 1. Collect images and import to CVAT
28
+ 2. Label the images in CVAT
29
+ 3. Export the labelled data in `COCO 1.0` format using CVAT
30
+
31
+ 1. Go to CVAT > `Projects` page
32
+ 2. Click `⋮` on `aorta` project
33
+ 3. Click `Export dataset`
34
+ - Format: `COCO 1.0`
35
+ - Save images: `Yes`
36
+
37
+ 4. Convert the exported (or re-split) data into YOLOv5 format (more `dataset.py` utilities are shown right after this list)
38
+
39
+ ```shell
40
+ python dataset.py coco2yolov5 [path/to/coco/input/dir] [path/to/yolov5/output/dir]
41
+ ```
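+
+ `dataset.py` bundles several other dataset utilities (merging/re-splitting, ROI cropping, listing image sizes); the exact sub-command names and options are best discovered through the built-in help:
+
+ ```shell
+ # list all dataset.py sub-commands
+ python dataset.py --help
+
+ # options of the COCO -> YOLOv5 converter shown above
+ python dataset.py coco2yolov5 --help
+ ```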
42
+
43
+ [CVAT](https://github.com/cvat-ai/cvat/tree/v2.3.0) info (set up with Docker Compose):
44
+
45
+ - Server version: 2.3
46
+ - Core version: 7.3.0
47
+ - Canvas version: 2.16.1
48
+ - UI version: 1.45.0
49
+
50
+ Dataset related scripts:
51
+
52
+ - [coco2yolov5seg.ipynb](../coco2yolov5seg.ipynb): Convert COCO format to YOLOv5 format for segmentation task
53
+ - [coco_merge_split.ipynb](../coco_merge_split.ipynb): Merge and split COCO format dataset
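+
+ `dataset.py` also provides a `newsplit` command that merges the exported train/test annotations and re-splits them leave-one-person-out (samples whose filepath matches the excluded name become the test set). A sketch only; the option names are inferred from the function signature, so double-check with `python dataset.py newsplit --help`:
+
+ ```shell
+ python dataset.py newsplit [path/to/coco/input/dir] --exclude-name Ellen
+ ```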
54
+
55
+ ## Training / Validation / Export
56
+
57
+ Model choice: prefer [yolov5-seg] over [yolov7-seg] for training, validation, and export, based on this comparison:
58
+
59
+ - yolov5s-seg: fast transfer learning (~5-10 min for 100 epochs on an RTX 3090) and fast CPU inference
60
+ - yolov7-seg: noticeably heavier, with slower CPU inference
61
+
62
+ Please refer to the [yolov5-seg] and [yolov7-seg] repos for details on training, validation, and export.
63
+
64
+ [yolov5-seg]: https://github.com/ultralytics/yolov5/blob/master/segment/tutorial.ipynb
65
+ [yolov7-seg]: https://github.com/WongKinYiu/yolov7/tree/u7/seg
66
+
67
+ ### yolov5-seg
68
+
69
+ Tested commit:
70
+
71
+ ```shell
72
+ # Assume work dir is aorta/
73
+ git clone https://github.com/ultralytics/yolov5
74
+ cd yolov5
75
+ git checkout 23c492321290266810e08fa5ee9a23fc9d6a571f
76
+ git apply ../add_clearml_yolov5.patch
77
+ ```
78
+
79
+ As of 2023, yolov5 seg does not support ClearML logging out of the box, but there is a [PR](https://github.com/ultralytics/yolov5/pull/10752) for it. Either update the files manually following that PR, or apply [add_clearml_yolov5.patch](./add_clearml_yolov5.patch) (as in the clone step above) to track the training process with ClearML.
80
+
81
+ ```shell
82
+ # Example
83
+ ## Original training script
84
+ python segment/train.py --img 640 --batch 16 --epochs 3 --data coco128-seg.yaml --weights yolov5s-seg.pt --cache
85
+
86
+ ## Updated training script with ClearML support
87
+ python segment/train.py --project [clearml_project_name] --name [task_name] --img 640 --batch 16 --epochs 3 --data coco128-seg.yaml --weights yolov5s-seg.pt --cache
88
+ ```
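+
+ After training, export the best checkpoint for CPU inference; the demo below expects an OpenVINO IR (`.xml`) or ONNX model. A minimal sketch using yolov5's `export.py` (the weights path and flags may differ between yolov5 versions, so verify with `python export.py --help`):
+
+ ```shell
+ # export the trained segmentation weights to ONNX and OpenVINO
+ python export.py --weights runs/train-seg/exp/weights/best.pt --include onnx openvino --imgsz 640
+ ```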
89
+
90
+ ## Test video
91
+
92
+ - Test video: [Demo.mp4](./Demo.mp4)
93
+ - The mp4 was converted from the original avi using `ffmpeg`:
94
+
95
+ ```shell
96
+ ffmpeg -i "Demo.avi" -vcodec h264 -acodec aac -b:v 500k -strict -2 Demo.mp4
97
+ ```
98
+
99
+ ## Demo (POC for 2022 Intel DevCup)
100
+
101
+ ```shell
102
+ # run demo, using openvino model
103
+ python demo.py --video Demo.mp4 --model weights/yolov5s-v2/best_openvino_model/yolov5-640-v2.xml --plot-mask --img-size 640
104
+
105
+ # or run the demo using onnx model
106
+ python demo.py --video Demo.mp4 --model weights/yolov5s-v2/yolov5-640.onnx --plot-mask --img-size 640
107
+
108
+ # or run in the headless mode, generating a recording of the demo
109
+ ./demo_headless.sh --video Demo.mp4 --model [path/to/model]
110
+ ```
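+
+ When `--video` is omitted, `demo.py` reads frames live from a Telemed probe through `usgfw2wrapper.dll` (Windows only); for example, `demo.bat` launches it as:
+
+ ```shell
+ python demo.py --device GPU --jobs 2
+ ```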
111
+
112
+ ## Deploy PyInstaller EXE
113
+
114
+ Only tested on Windows 10:
115
+
116
+ ```shell
117
+ pip install pyinstaller==5.9
118
+ pyinstaller demo.py
119
+ # (TODO) Replace the following manual steps with pyinstaller --add-data or spec file
120
+ #
121
+ # Manually copy these files to dist\demo
122
+ # 1. Copy best_openvino_model folder to dist\demo\
123
+ # 2. Copy openvino files to dist\demo
124
+ # C:\Users\sa\miniforge3\envs\echo\Lib\site-packages\openvino\libs
125
+ # plugins.xml
126
+ # openvino_ir_frontend.dll
127
+ # openvino_intel_cpu_plugin.dll
128
+ # openvino_intel_gpu_plugin.dll
129
+ ```
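+
+ The manual copy steps above could likely be folded into the build with PyInstaller's `--add-data` (a sketch only; `SRC;DEST` with a `;` separator is the Windows form, and the source path must match your local layout):
+
+ ```shell
+ pyinstaller demo.py --add-data "weights/yolov5s-v2/best_openvino_model;best_openvino_model"
+ ```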
130
+
131
+ Troubleshooting: if the deployed EXE fails with `ValueError: --plotlyjs argument is not a valid URL or file path:`, move the dist folder to a location whose path contains no special or Chinese characters. Reference: <https://github.com/plotly/Kaleido/issues/57>
132
+
133
+
134
+ ## Paper
135
+
136
+ https://www.nature.com/articles/s41746-024-01269-4
137
+
138
+ Chiu, IM., Chen, TY., Zheng, YC. et al. Prospective clinical evaluation of deep learning for ultrasonographic screening of abdominal aortic aneurysms. npj Digit. Med. 7, 282 (2024).
139
+ <!-- #endregion -->
140
+
141
+ ```python
142
+
143
+ ```
add_clearml_yolov5.patch ADDED
@@ -0,0 +1,215 @@
1
+ diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py
2
+ index 9de1f22..93b9ba2 100644
3
+ --- a/utils/loggers/__init__.py
4
+ +++ b/utils/loggers/__init__.py
5
+ @@ -110,7 +110,7 @@ class Loggers():
6
+ if clearml and 'clearml' in self.include:
7
+ try:
8
+ self.clearml = ClearmlLogger(self.opt, self.hyp)
9
+ - except Exception:
10
+ + except Exception as e:
11
+ self.clearml = None
12
+ prefix = colorstr('ClearML: ')
13
+ LOGGER.warning(f'{prefix}WARNING ⚠️ ClearML is installed but not configured, skipping ClearML logging.'
14
+ @@ -159,10 +159,11 @@ class Loggers():
15
+ paths = self.save_dir.glob('*labels*.jpg') # training labels
16
+ if self.wandb:
17
+ self.wandb.log({'Labels': [wandb.Image(str(x), caption=x.name) for x in paths]})
18
+ - # if self.clearml:
19
+ - # pass # ClearML saves these images automatically using hooks
20
+ if self.comet_logger:
21
+ self.comet_logger.on_pretrain_routine_end(paths)
22
+ + if self.clearml:
23
+ + for path in paths:
24
+ + self.clearml.log_plot(title=path.stem, plot_path=path)
25
+
26
+ def on_train_batch_end(self, model, ni, imgs, targets, paths, vals):
27
+ log_dict = dict(zip(self.keys[:3], vals))
28
+ @@ -289,6 +290,8 @@ class Loggers():
29
+ self.wandb.finish_run()
30
+
31
+ if self.clearml and not self.opt.evolve:
32
+ + self.clearml.log_summary(dict(zip(self.keys[3:10], results)))
33
+ + [self.clearml.log_plot(title=f.stem, plot_path=f) for f in files]
34
+ self.clearml.task.update_output_model(model_path=str(best if best.exists() else last),
35
+ name='Best Model',
36
+ auto_delete_file=False)
37
+ @@ -303,6 +306,8 @@ class Loggers():
38
+ self.wandb.wandb_run.config.update(params, allow_val_change=True)
39
+ if self.comet_logger:
40
+ self.comet_logger.on_params_update(params)
41
+ + if self.clearml:
42
+ + self.clearml.task.connect(params)
43
+
44
+
45
+ class GenericLogger:
46
+ @@ -315,7 +320,7 @@ class GenericLogger:
47
+ include: loggers to include
48
+ """
49
+
50
+ - def __init__(self, opt, console_logger, include=('tb', 'wandb')):
51
+ + def __init__(self, opt, console_logger, include=('tb', 'wandb', 'clearml')):
52
+ # init default loggers
53
+ self.save_dir = Path(opt.save_dir)
54
+ self.include = include
55
+ @@ -333,6 +338,22 @@ class GenericLogger:
56
+ config=opt)
57
+ else:
58
+ self.wandb = None
59
+ +
60
+ + if clearml and 'clearml' in self.include:
61
+ + try:
62
+ + # Hyp is not available in classification mode
63
+ + if 'hyp' not in opt:
64
+ + hyp = {}
65
+ + else:
66
+ + hyp = opt.hyp
67
+ + self.clearml = ClearmlLogger(opt, hyp)
68
+ + except Exception:
69
+ + self.clearml = None
70
+ + prefix = colorstr('ClearML: ')
71
+ + LOGGER.warning(f'{prefix}WARNING ⚠️ ClearML is installed but not configured, skipping ClearML logging.'
72
+ + f' See https://github.com/ultralytics/yolov5/tree/master/utils/loggers/clearml#readme')
73
+ + else:
74
+ + self.clearml = None
75
+
76
+ def log_metrics(self, metrics, epoch):
77
+ # Log metrics dictionary to all loggers
78
+ @@ -349,6 +370,9 @@ class GenericLogger:
79
+
80
+ if self.wandb:
81
+ self.wandb.log(metrics, step=epoch)
82
+ +
83
+ + if self.clearml:
84
+ + self.clearml.log_scalars(metrics, epoch)
85
+
86
+ def log_images(self, files, name='Images', epoch=0):
87
+ # Log images to all loggers
88
+ @@ -361,6 +385,12 @@ class GenericLogger:
89
+
90
+ if self.wandb:
91
+ self.wandb.log({name: [wandb.Image(str(f), caption=f.name) for f in files]}, step=epoch)
92
+ +
93
+ + if self.clearml:
94
+ + if name == 'Results':
95
+ + [self.clearml.log_plot(f.stem, f) for f in files]
96
+ + else:
97
+ + self.clearml.log_debug_samples(files, title=name)
98
+
99
+ def log_graph(self, model, imgsz=(640, 640)):
100
+ # Log model graph to all loggers
101
+ @@ -373,11 +403,17 @@ class GenericLogger:
102
+ art = wandb.Artifact(name=f'run_{wandb.run.id}_model', type='model', metadata=metadata)
103
+ art.add_file(str(model_path))
104
+ wandb.log_artifact(art)
105
+ +
106
+ + if self.clearml:
107
+ + self.clearml.log_model(model_path=model_path, model_name=model_path.stem)
108
+
109
+ def update_params(self, params):
110
+ # Update the parameters logged
111
+ if self.wandb:
112
+ wandb.run.config.update(params, allow_val_change=True)
113
+ +
114
+ + if self.clearml:
115
+ + self.clearml.task.connect(params)
116
+
117
+
118
+ def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
119
+ diff --git a/utils/loggers/clearml/clearml_utils.py b/utils/loggers/clearml/clearml_utils.py
120
+ index 2764abe..e7525da 100644
121
+ --- a/utils/loggers/clearml/clearml_utils.py
122
+ +++ b/utils/loggers/clearml/clearml_utils.py
123
+ @@ -3,6 +3,9 @@ import glob
124
+ import re
125
+ from pathlib import Path
126
+
127
+ +import matplotlib.image as mpimg
128
+ +import matplotlib.pyplot as plt
129
+ +
130
+ import numpy as np
131
+ import yaml
132
+
133
+ @@ -79,13 +82,16 @@ class ClearmlLogger:
134
+ # Maximum number of images to log to clearML per epoch
135
+ self.max_imgs_to_log_per_epoch = 16
136
+ # Get the interval of epochs when bounding box images should be logged
137
+ - self.bbox_interval = opt.bbox_interval
138
+ + # Only for detection task though!
139
+ + if 'bbox_interval' in opt:
140
+ + self.bbox_interval = opt.bbox_interval
141
+ self.clearml = clearml
142
+ self.task = None
143
+ self.data_dict = None
144
+ if self.clearml:
145
+ self.task = Task.init(
146
+ - project_name=opt.project if opt.project != 'runs/train' else 'YOLOv5',
147
+ + # project_name=opt.project if opt.project != 'runs/train' else 'YOLOv5',
148
+ + project_name=opt.project if not str(opt.project).startswith('runs/') else 'YOLOv5',
149
+ task_name=opt.name if opt.name != 'exp' else 'Training',
150
+ tags=['YOLOv5'],
151
+ output_uri=True,
152
+ @@ -112,6 +118,53 @@ class ClearmlLogger:
153
+ # Set data to data_dict because wandb will crash without this information and opt is the best way
154
+ # to give it to them
155
+ opt.data = self.data_dict
156
+ +
157
+ + def log_scalars(self, metrics, epoch):
158
+ + """
159
+ + Log scalars/metrics to ClearML
160
+ + arguments:
161
+ + metrics (dict) Metrics in dict format: {"metrics/mAP": 0.8, ...}
162
+ + epoch (int) iteration number for the current set of metrics
163
+ + """
164
+ + for k, v in metrics.items():
165
+ + title, series = k.split('/')
166
+ + self.task.get_logger().report_scalar(title, series, v, epoch)
167
+ +
168
+ + def log_model(self, model_path, model_name, epoch=0):
169
+ + """
170
+ + Log model weights to ClearML
171
+ + arguments:
172
+ + model_path (PosixPath or str) Path to the model weights
173
+ + model_name (str) Name of the model visible in ClearML
174
+ + epoch (int) Iteration / epoch of the model weights
175
+ + """
176
+ + self.task.update_output_model(model_path=str(model_path),
177
+ + name=model_name,
178
+ + iteration=epoch,
179
+ + auto_delete_file=False)
180
+ +
181
+ + def log_summary(self, metrics):
182
+ + """
183
+ + Log final metrics to a summary table
184
+ + arguments:
185
+ + metrics (dict) Metrics in dict format: {"metrics/mAP": 0.8, ...}
186
+ + """
187
+ + for k, v in metrics.items():
188
+ + self.task.get_logger().report_single_value(k, v)
189
+ +
190
+ + def log_plot(self, title, plot_path):
191
+ + """
192
+ + Log image as plot in the plot section of ClearML
193
+ + arguments:
194
+ + title (str) Title of the plot
195
+ + plot_path (PosixPath or str) Path to the saved image file
196
+ + """
197
+ + img = mpimg.imread(plot_path)
198
+ + fig = plt.figure()
199
+ + ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
200
+ + ax.imshow(img)
201
+ +
202
+ + self.task.get_logger().report_matplotlib_figure(title, "", figure=fig, report_interactive=False)
203
+
204
+ def log_debug_samples(self, files, title='Debug Samples'):
205
+ """
206
+ @@ -126,7 +179,8 @@ class ClearmlLogger:
207
+ it = re.search(r'_batch(\d+)', f.name)
208
+ iteration = int(it.groups()[0]) if it else 0
209
+ self.task.get_logger().report_image(title=title,
210
+ - series=f.name.replace(it.group(), ''),
211
+ + # series=f.name.replace(it.group(), ''),
212
+ + series=f.name.replace(f"_batch{iteration}", ''),
213
+ local_path=str(f),
214
+ iteration=iteration)
215
+
dataset.py ADDED
@@ -0,0 +1,526 @@
1
+ import typer
2
+ import fiftyone as fo
3
+ from fiftyone import ViewField as F
4
+ from pathlib import Path
5
+ from pycocotools.coco import COCO
6
+ from loguru import logger
7
+ import cv2
8
+ import shutil
9
+ import os
10
+ import random
11
+ from collections import defaultdict
12
+ import csv
13
+
14
+
15
+ DEFAULT_EXCLUDE_NAME = "Ellen"
16
+ DEFAULT_INS_TRAIN = "instances_Train.json"
17
+ DEFAULT_INS_TEST = "instances_Test.json"
18
+
19
+ app = typer.Typer()
20
+
21
+
22
+ @app.command()
23
+ def newsplit(
24
+ in_dir: str,
25
+ train_json=DEFAULT_INS_TRAIN,
26
+ test_json=DEFAULT_INS_TEST,
27
+ exclude_name=DEFAULT_EXCLUDE_NAME,
28
+ ):
29
+ """
30
+ Merge the train and test datasets,
31
+ and then split them into new train/test by leaving one person out.
32
+ """
33
+
34
+ # load the dataset
35
+ logger.info("Loading datasets...")
36
+ ds1 = fo.Dataset.from_dir(
37
+ dataset_type=fo.types.COCODetectionDataset,
38
+ data_path=Path(in_dir) / "images",
39
+ labels_path=Path(in_dir) / "annotations" / train_json,
40
+ )
41
+ ds2 = fo.Dataset.from_dir(
42
+ dataset_type=fo.types.COCODetectionDataset,
43
+ data_path=Path(in_dir) / "images",
44
+ labels_path=Path(in_dir) / "annotations" / test_json,
45
+ )
46
+
47
+ logger.info(f"[Before] Num samples in train: {len(ds1)}")
48
+ logger.info(f"[Before] Num samples in test: {len(ds2)}")
49
+
50
+ # merge the datasets
51
+ ds1.merge_samples(ds2)
52
+
53
+ # generate the new split
54
+     logger.info(f"Excluding samples whose filepath matches '{exclude_name}' from the train set")
55
+ new_train_view = ds1.match(~F("filepath").re_match(exclude_name))
56
+ new_test_view = ds1.match(F("filepath").re_match(exclude_name))
57
+ assert len(new_train_view) + len(new_test_view) == len(ds1)
58
+ logger.info(f"[After] Num samples in train: {len(new_train_view)}")
59
+ logger.info(f"[After] Num samples in test: {len(new_test_view)}")
60
+ train_counts = new_train_view.count_values("detections.detections.label")
61
+ test_counts = new_test_view.count_values("detections.detections.label")
62
+ logger.info(f"[After] Train counts: {train_counts}")
63
+ logger.info(f"[After] Test counts: {test_counts}")
64
+
65
+ # export the new split
66
+ logger.info("Exporting new train/test...")
67
+ new_train_p = Path(in_dir) / "annotations" / f"new_train_no-{exclude_name}.json"
68
+ new_test_p = Path(in_dir) / "annotations" / f"new_test_{exclude_name}.json"
69
+ new_train_view.export(
70
+ dataset_type=fo.types.COCODetectionDataset,
71
+ labels_path=new_train_p,
72
+ label_field="segmentations",
73
+ classes=ds1.default_classes,
74
+ abs_paths=True,
75
+ )
76
+ new_test_view.export(
77
+ dataset_type=fo.types.COCODetectionDataset,
78
+ labels_path=new_test_p,
79
+ label_field="segmentations",
80
+ classes=ds2.default_classes,
81
+ abs_paths=True,
82
+ )
83
+ logger.info(f"Exported new train: {new_train_p}")
84
+ logger.info(f"Exported new test: {new_test_p}")
85
+
86
+
87
+ def _normalize(img_size, xy_s):
88
+ assert len(xy_s) % 2 == 0
89
+ normalized_xy_s = []
90
+ dw = 1.0 / (img_size[0])
91
+ dh = 1.0 / (img_size[1])
92
+ for i in range(len(xy_s)):
93
+ p = xy_s[i]
94
+ p = p * dw if i % 2 == 0 else p * dh
95
+ assert p <= 1.0 and p >= 0.0, f"{p} should < 1 and > 0"
96
+ normalized_xy_s.append(p)
97
+ return normalized_xy_s
98
+
99
+
100
+ def _coco2yolo(coco_img_dir, coco_json_path, out_dir, bbox_only=False, rois=None):
101
+ logger.info(f"Reading {Path(coco_json_path).name}...")
102
+ coco = COCO(coco_json_path)
103
+
104
+ cats = coco.loadCats(coco.getCatIds())
105
+ cats = sorted(cats, key=lambda x: x["id"], reverse=False)
106
+ assert cats[0]["id"] == 1, f"Assume cat id starts from 1, but got {cats[0]['id']}"
107
+ logger.info(f"{len(cats)} categories: {[cat['name'] for cat in cats]}")
108
+
109
+ img_ids = coco.getImgIds()
110
+ prefix = Path(coco_json_path).stem.split("_")[-1].lower() # either train or test
111
+
112
+ # create output directories
113
+ target_txt_r = Path(out_dir) / prefix / "labels"
114
+ target_img_r = Path(out_dir) / prefix / "images"
115
+ target_txt_r.mkdir(parents=True, exist_ok=False)
116
+ target_img_r.mkdir(parents=True, exist_ok=False)
117
+
118
+ logger.info(f"Num of imgs: {len(img_ids)}")
119
+
120
+ n_imgs_no_annos = 0
121
+ num_zero_area = 0
122
+ for img_id in img_ids:
123
+ img = coco.loadImgs(img_id)[0]
124
+ img_p = Path(coco_img_dir) / img["file_name"]
125
+ assert img_p.exists(), f"{img_p} does not exist"
126
+
127
+ anno_ids = coco.getAnnIds(imgIds=img["id"])
128
+ annos = coco.loadAnns(anno_ids)
129
+
130
+ new_filename = f"{img['id']}_{img_p.stem}"
131
+
132
+ out_img_p = target_img_r / (new_filename + img_p.suffix)
133
+
134
+ # get roi for the image if any
135
+ im_cv = cv2.imread(img_p.as_posix())
136
+ im_width, im_height = im_cv.shape[1], im_cv.shape[0]
137
+ roi = rois[(im_width, im_height)] if rois is not None else None
138
+ has_roi = (rois is not None) and (roi is not None) and len(roi) == 4
139
+ if not has_roi:
140
+ # copy image to target dir
141
+ shutil.copy(img_p, out_img_p)
142
+ else:
143
+ # crop the image to target dir
144
+ assert len(roi) == 4, f"ROI should have 4 values, but got {roi}"
145
+ cropped_img = im_cv[roi[1] : roi[1] + roi[3], roi[0] : roi[0] + roi[2]]
146
+ cv2.imwrite(out_img_p.as_posix(), cropped_img)
147
+
148
+ # bg imgs: only need to copy img, no need to create label file
149
+ if len(annos) == 0:
150
+ n_imgs_no_annos += 1
151
+ continue
152
+
153
+ # create the label txt file
154
+ txt_p = Path(target_txt_r) / (new_filename + ".txt")
155
+ if txt_p.exists():
156
+             logger.warning(f"{txt_p} already exists and will be overwritten by {img_p}")
157
+ txt_f = open(txt_p, "w")
158
+ img = cv2.imread(img_p.as_posix())
159
+ h, w, _ = img.shape
160
+
161
+ # generate txt file for each image
162
+ for ann in annos:
163
+ cls_id = ann["category_id"] - 1 # yolov5 uses zero-based class idx
164
+
165
+ # region bbox, for object detection
166
+ if bbox_only:
167
+ bbox = ann["bbox"]
168
+ # convert coco to yolo: top-x, top-y, w, h -> center-x, center-y, w, h
169
+ bbox_yolo = [
170
+ bbox[0] + bbox[2] / 2,
171
+ bbox[1] + bbox[3] / 2,
172
+ bbox[2],
173
+ bbox[3],
174
+ ]
175
+ n_bbox_p = " ".join([str(a) for a in _normalize((w, h), bbox_yolo)])
176
+ txt_f.write(f"{cls_id} {n_bbox_p}{os.linesep}")
177
+ continue
178
+ # endregion
179
+
180
+ # region seg, for instance segmentation
181
+ seg = ann["segmentation"]
182
+ if len(seg) > 1:
183
+ # TODO: Investigate why sometimes there are multiple segs
184
+ logger.warning(f"Skip {img_p} with {len(seg)} segs of {ann}")
185
+ continue
186
+
187
+ if len(seg) == 1:
188
+ xy_s = seg[0]
189
+ # handle roi if any
190
+ if has_roi:
191
+ xy_s = [xy - roi[i % 2] for i, xy in enumerate(xy_s)]
192
+ w, h = roi[2], roi[3]
193
+ # remove the points outside of roi
194
+ new_xy_s = []
195
+ for i in range(0, len(xy_s), 2):
196
+ x, y = xy_s[i], xy_s[i + 1]
197
+ if x >= 0 and x <= w and y >= 0 and y <= h:
198
+ new_xy_s.extend([x, y])
199
+ xy_s = new_xy_s
200
+ n_xy_s = _normalize((w, h), xy_s)
201
+ seg_p = " ".join([str(a) for a in n_xy_s])
202
+ txt_f.write(f"{cls_id} {seg_p}{os.linesep}")
203
+ # endregion
204
+
205
+ # region keypoint, for pose estimation
206
+ if "keypoints" in ann:
207
+ # skip area 0 keypoints which could cause yolov8 training error
208
+ if int(ann["area"]) == 0:
209
+ num_zero_area += 1
210
+ continue
211
+ kps = ann["keypoints"]
212
+ bbox = ann["bbox"]
213
+ # convert coco to yolo: top-x, top-y, w, h -> center-x, center-y, w, h
214
+ bbox_yolo = [
215
+ bbox[0] + bbox[2] / 2,
216
+ bbox[1] + bbox[3] / 2,
217
+ bbox[2],
218
+ bbox[3],
219
+ ]
220
+ n_bbox_p = " ".join([str(a) for a in _normalize((w, h), bbox_yolo)])
221
+ # normalize x,y of each keypoint and keep visibility as is
222
+ n_kp = []
223
+ for i in range(0, len(kps), 3):
224
+ n_kp.append(kps[i] / w)
225
+ n_kp.append(kps[i + 1] / h)
226
+ n_kp.append(kps[i + 2])
227
+ n_kp_p = " ".join([str(a) for a in n_kp])
228
+ txt_f.write(f"{cls_id} {n_bbox_p} {n_kp_p}{os.linesep}")
229
+ # endregion
230
+ txt_f.close()
231
+ # remove empty label file which has no annos
232
+ if txt_p.stat().st_size == 0:
233
+ txt_p.unlink()
234
+ n_imgs_no_annos += 1
235
+ empty_ratio = 100 * float(n_imgs_no_annos) / len(img_ids)
236
+ n_imgs_anns = len(img_ids) - n_imgs_no_annos
237
+ logger.info(f"# imgs w anns: {n_imgs_anns} {(100-empty_ratio):.2f}%")
238
+ logger.info(f"# imgs w/o anns: {n_imgs_no_annos} {empty_ratio:.2f}%")
239
+ logger.info(f"# zero area kps: {num_zero_area}")
240
+ txts = [f for f in target_txt_r.iterdir() if f.is_file()]
241
+ imgs = [f for f in target_img_r.iterdir() if f.is_file()]
242
+ assert (len(txts) + n_imgs_no_annos) == len(imgs) == len(img_ids)
243
+ return target_img_r
244
+
245
+
246
+ @app.command(help="Convert COCO dataset to YOLOv5 format")
247
+ def coco2yolov5(
248
+ in_dir: str,
249
+ out_dir: str,
250
+ split_val_ratio: float = 0.2,
251
+ seed: int = 42,
252
+ bbox_only: bool = False,
253
+ crop_roi_file: str = None,
254
+ ):
255
+ """
256
+ Convert COCO dataset to YOLOv5 format.
257
+ Support 3 task types: object detection, instance segmentation, pose estimation.
258
+
259
+ YOLOv5 seg labels are the same as detection labels, using txt files with one object per line.
260
+ The difference is that instead of "class, xywh" they are "class xy1, xy2, xy3,...".
261
+ Ref: https://github.com/ultralytics/yolov5/issues/10161#issuecomment-1315672357
262
+
263
+     YOLOv5 keypoint labels use txt files with one object per line.
264
+ class cx cy w h x1 y1 v1 ... xn yn vn
265
+ All coordinates are normalized by image width and height.
266
+ vn (visibility): 0, 1, or 2 => not labeled, labeled but invisible, labeled and visible
267
+ Ref: https://github.com/ultralytics/ultralytics/issues/1870#issuecomment-1498909244
268
+ Example: https://ultralytics.com/assets/coco8-pose.zip
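+
+     Illustrative label lines (values are made up; all coordinates normalized to [0, 1]):
+         seg:      "0 0.51 0.43 0.55 0.47 0.52 0.50"    (class x1 y1 x2 y2 x3 y3)
+         keypoint: "0 0.50 0.50 0.20 0.30 0.51 0.48 2"  (class cx cy w h x y v)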
269
+ """
270
+ if Path(out_dir).exists():
271
+         delete = typer.confirm(f"{out_dir} already exists. Are you sure you want to delete it?")
272
+ if not delete:
273
+ logger.info("Not deleting")
274
+ raise typer.Abort()
275
+ shutil.rmtree(out_dir)
276
+ logger.info(f"Deleted {Path(out_dir).name}")
277
+
278
+ ann_dir_p = Path(in_dir) / "annotations"
279
+ img_dir_p = Path(in_dir) / "images"
280
+ assert ann_dir_p.exists(), f"{ann_dir_p} does not exist"
281
+ assert img_dir_p.exists(), f"{img_dir_p} does not exist"
282
+
283
+ # try to find the json files of train & test in annotations dir
284
+ train_json_p = None
285
+ test_json_p = None
286
+ for f in ann_dir_p.iterdir():
287
+ if f.stem.lower().endswith("train"):
288
+ train_json_p = f
289
+ logger.info(f"Found train json: {f.name}")
290
+ elif f.stem.lower().endswith("test"):
291
+ test_json_p = f
292
+ logger.info(f"Found test json: {f.name}")
293
+ # must have train, while test is optional
294
+ assert train_json_p is not None, f"Cannot find train json in {ann_dir_p}"
295
+ do_split = False
296
+ if test_json_p is None:
297
+ logger.warning("Cannot find test json in [in_dir]/annotations")
298
+ do_split = typer.confirm("Do you want to split val from train?")
299
+
300
+ # region handle ROIs
301
+ rois = None
302
+ if crop_roi_file is not None:
303
+ roi_csv_p = Path(crop_roi_file)
304
+ assert roi_csv_p.exists(), f"{roi_csv_p} does not exist"
305
+ # read ROIs from csv, each image size should have one ROI
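+         # expected CSV header: ori_width,ori_height,roi_x,roi_y,roi_width,roi_height (one row per image size)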
306
+ rois = defaultdict(lambda: [], {})
307
+ with open(roi_csv_p, "r") as f:
308
+ for roi in csv.DictReader(f):
309
+ ori_width = int(roi["ori_width"])
310
+ ori_height = int(roi["ori_height"])
311
+ roi_x = int(roi["roi_x"])
312
+ roi_y = int(roi["roi_y"])
313
+ roi_width = int(roi["roi_width"])
314
+ roi_height = int(roi["roi_height"])
315
+
316
+ key = (ori_width, ori_height)
317
+ assert key not in rois, f"Duplicate ROI for {key}"
318
+ rois[key] = [roi_x, roi_y, roi_width, roi_height]
319
+ # endregion
320
+
321
+ yolo_train_img_dir = None
322
+ yolo_test_img_dir = None
323
+ yolo_train_img_dir = _coco2yolo(img_dir_p, train_json_p, out_dir, bbox_only, rois)
324
+ if test_json_p is not None:
325
+ yolo_test_img_dir = _coco2yolo(img_dir_p, test_json_p, out_dir, bbox_only, rois)
326
+
327
+ if do_split:
328
+ yolo_test_img_dir = Path(out_dir) / "val" / "images"
329
+ # randomly select 20% of train images
330
+ train_imgs = [f for f in yolo_train_img_dir.iterdir() if f.is_file()]
331
+ n_test = int(len(train_imgs) * split_val_ratio)
332
+ logger.info(f"Split ratio {split_val_ratio}: {n_test} test images from train")
333
+ # set random seed to make sure the same images are selected
334
+ random.seed(seed)
335
+ test_imgs = random.sample(train_imgs, n_test)
336
+ # move test images to val/images
337
+ yolo_test_img_dir.mkdir(parents=True, exist_ok=True)
338
+ for f in test_imgs:
339
+ shutil.move(str(f), str(yolo_test_img_dir))
340
+ # move labels of test images to val/labels
341
+ yolo_test_label_dir = Path(out_dir) / "val" / "labels"
342
+ yolo_test_label_dir.mkdir(parents=True, exist_ok=True)
343
+ for f in test_imgs:
344
+ label_f = yolo_train_img_dir.parent / "labels" / f"{f.stem}.txt"
345
+ if label_f.exists():
346
+ shutil.move(str(label_f), str(yolo_test_label_dir))
347
+
348
+ # region create yaml file
349
+
350
+ logger.info(f"Reading {Path(train_json_p).name}...")
351
+ train_coco = COCO(train_json_p)
352
+ train_cats = train_coco.loadCats(train_coco.getCatIds())
353
+ num_kps = [
354
+ len(c["keypoints"])
355
+ for c in train_cats
356
+ if "keypoints" in c and len(c["keypoints"]) > 0
357
+ ]
358
+ # check if all categories have the same number of keypoints
359
+ if len(num_kps) > 0:
360
+ assert len(set(num_kps)) == 1, "Categories have different number of keypoints"
361
+ logger.info(f"Number of keypoints: {set(num_kps)}")
362
+ train_cats = [c["name"] for c in train_cats]
363
+ # ensure having the same categories in the json of train & test
364
+ # test_coco = COCO(test_json_p)
365
+ # test_cats = test_coco.loadCats(test_coco.getCatIds())
366
+ # test_cats = sorted(test_cats, key=lambda x: x["id"], reverse=False)
367
+ # test_cats = [c["name"] for c in test_cats]
368
+ # assert ",".join(train_cats) == ",".join(test_cats), "Categories mismatch"
369
+
370
+ out_config_file = Path(out_dir) / "data.yaml"
371
+ with open(out_config_file, "w") as f:
372
+ if len(num_kps) > 0:
373
+ f.write(f"kpt_shape: [{num_kps[0]},3]" + os.linesep)
374
+ assert num_kps[0] == 1, "Only support 1 keypoint for now"
375
+ f.write("flip_idx: [0]" + os.linesep)
376
+ f.write("names:" + os.linesep)
377
+ for c in train_cats:
378
+ f.write(f"- {c}" + os.linesep)
379
+ f.write(f"nc: {len(train_cats)}" + os.linesep)
380
+ f.write(f"path: {Path(out_dir).absolute()}" + os.linesep)
381
+ train_rel_path = f"{yolo_train_img_dir.parent.name}/{yolo_train_img_dir.name}"
382
+ f.write(f"train: {train_rel_path}" + os.linesep)
383
+ if yolo_test_img_dir is not None:
384
+ val_rel_path = f"{yolo_test_img_dir.parent.name}/{yolo_test_img_dir.name}"
385
+ f.write(f"val: {val_rel_path}" + os.linesep)
386
+
387
+ logger.info(f"Config file saved: {out_config_file}")
388
+ # endregion
389
+
390
+ logger.info("Done ✅")
391
+
392
+
393
+ @app.command(help="List all image sizes and counts in a directory recursively")
394
+ def list_img_sizes(
395
+ in_dir: str = typer.Argument(..., help="Input directory"),
396
+ ):
397
+ in_dir_p = Path(in_dir)
398
+ assert in_dir_p.exists(), f"{in_dir_p} does not exist"
399
+ assert in_dir_p.is_dir(), f"{in_dir_p} is not a directory"
400
+
401
+ ds = fo.Dataset.from_images_dir(in_dir_p)
402
+ ds.compute_metadata()
403
+
404
+ logger.info(f"Found {len(ds)} images in {in_dir_p}")
405
+
406
+ # count number of images for each size
407
+ sizes = defaultdict(lambda: 0, {})
408
+ for sample in ds:
409
+ metadata = sample.metadata
410
+ width = metadata.width
411
+ height = metadata.height
412
+ sizes[(width, height)] += 1
413
+ # sort with the most frequent size first
414
+ sizes = dict(sorted(sizes.items(), key=lambda x: x[1], reverse=True))
415
+ for k, v in sizes.items():
416
+ # find one example image for each size
417
+ sample = ds.match({"metadata.width": k[0], "metadata.height": k[1]}).first()
418
+ print(f"Size (w, h) {k}: {v} image(s), e.g., {sample.filepath}")
419
+
420
+
421
+ @app.command(help="Crop images in a directory recursively with ROIs from csv")
422
+ def crop_imgs(
423
+ in_dir: str = typer.Argument(..., help="Input directory"),
424
+ roi_csv: str = typer.Argument(..., help="CSV file containing ROIs"),
425
+ ):
426
+ in_dir_p = Path(in_dir)
427
+ assert in_dir_p.exists(), f"{in_dir_p} does not exist"
428
+ assert in_dir_p.is_dir(), f"{in_dir_p} is not a directory"
429
+
430
+ roi_csv_p = Path(roi_csv)
431
+ assert roi_csv_p.exists(), f"{roi_csv_p} does not exist"
432
+
433
+ # read ROIs from csv, each image size should have one ROI
434
+ rois = defaultdict(lambda: [], {})
435
+ with open(roi_csv_p, "r") as f:
436
+ for roi in csv.DictReader(f):
437
+ ori_width = int(roi["ori_width"])
438
+ ori_height = int(roi["ori_height"])
439
+ roi_x = int(roi["roi_x"])
440
+ roi_y = int(roi["roi_y"])
441
+ roi_width = int(roi["roi_width"])
442
+ roi_height = int(roi["roi_height"])
443
+
444
+ key = (ori_width, ori_height)
445
+ assert key not in rois, f"Duplicate ROI for {key}"
446
+ rois[key] = [roi_x, roi_y, roi_width, roi_height]
447
+
448
+ # read and crop images
449
+ # write the cropped images to a new directory
450
+ out_dir_p = in_dir_p.parent / f"{in_dir_p.name}_cropped"
451
+ Path(out_dir_p).mkdir(parents=True, exist_ok=True)
452
+ ds = fo.Dataset.from_images_dir(in_dir_p)
453
+ logger.info(f"Found {len(ds)} images in {in_dir_p}")
454
+ for sample in ds:
455
+ img_path = sample.filepath
456
+
457
+ # read and crop the image
458
+ img = cv2.imread(img_path)
459
+ width, height = img.shape[1], img.shape[0]
460
+ roi = rois[(width, height)]
461
+ cropped_img = img[roi[1] : roi[1] + roi[3], roi[0] : roi[0] + roi[2]]
462
+
463
+ # keep the original folder structure
464
+ out_img_p = out_dir_p / Path(img_path).relative_to(in_dir_p.absolute())
465
+ # create the subfolder if not exist
466
+ if not out_img_p.parent.exists():
467
+ out_img_p.parent.mkdir(parents=True, exist_ok=True)
468
+ cv2.imwrite(str(out_img_p), cropped_img)
469
+ logger.info(f"Cropped images saved to {out_dir_p}")
470
+
471
+
472
+ @app.command(help="Count num of images without aorta annotations")
473
+ def count_n_imgs_no_aorta(
474
+ in_coco_json_p: str = typer.Argument(..., help="Input coco json file"),
475
+ aorta_cat_name: str = typer.Argument("aorta", help="Name of aorta category"),
476
+ ):
477
+ logger.info(f"Reading {Path(in_coco_json_p).name}...")
478
+ assert Path(in_coco_json_p).exists(), f"{in_coco_json_p} does not exist"
479
+ coco = COCO(in_coco_json_p)
480
+
481
+ cats = coco.loadCats(coco.getCatIds())
482
+ cats = sorted(cats, key=lambda x: x["id"], reverse=False)
483
+ # find the category id of aorta
484
+ aorta_cat_id = None
485
+ for cat in cats:
486
+ if cat["name"] == aorta_cat_name:
487
+ aorta_cat_id = cat["id"]
488
+ break
489
+ assert aorta_cat_id is not None, f"Cannot find {aorta_cat_name} in {in_coco_json_p}"
490
+ logger.info(f"Found {aorta_cat_name} with id {aorta_cat_id}")
491
+
492
+ n_img_no_aorta = 0
493
+ for img_id in coco.getImgIds():
494
+ anno_ids = coco.getAnnIds(imgIds=img_id)
495
+ annos = coco.loadAnns(anno_ids)
496
+ has_aorta = False
497
+ for anno in annos:
498
+ if anno["category_id"] == aorta_cat_id:
499
+ has_aorta = True
500
+ break
501
+ if not has_aorta:
502
+ n_img_no_aorta += 1
503
+ logger.info(f"Found {n_img_no_aorta} images without {aorta_cat_name}")
504
+
505
+
506
+ @app.command(help="Remove non-aorta annotations from a YOLOv5 dataset")
507
+ def keep_only_aorta_labels_in_yolo(
508
+ in_dir: str = typer.Argument(..., help="Input label directory"),
509
+ aorta_class_id: int = typer.Argument(0, help="Class id of aorta"),
510
+ ):
511
+ txts = list(Path(in_dir).glob("*.txt"))
512
+ logger.info(f"Found {len(txts)} txt files in {in_dir}")
513
+ for txt_p in txts:
514
+ ori_lines, new_lines = [], []
515
+ with open(txt_p, "r") as f:
516
+ ori_lines = f.readlines()
517
+ for line in ori_lines:
518
+ nums = line.split(" ")
519
+ if int(nums[0]) == aorta_class_id:
520
+ new_lines.append(line)
521
+ with open(txt_p, "w") as new_f:
522
+ new_f.writelines(new_lines)
523
+
524
+
525
+ if __name__ == "__main__":
526
+ app()
demo.bat ADDED
@@ -0,0 +1,12 @@
1
+ @echo off
2
+
3
+ REM "Please change this path to your own"
4
+ cd "C:\Users\chenp\Downloads\aorta_demo_v3"
5
+
6
+ REM "Please change this path to your own"
7
+ call C:\ProgramData\miniconda3\Scripts\activate.bat
8
+
9
+ call conda activate echo
10
+ call python demo.py --device GPU --jobs 2
11
+
12
+ pause
demo.py ADDED
@@ -0,0 +1,781 @@
1
+ from pathlib import Path
2
+ import sys
3
+ import time
4
+ from time import perf_counter
5
+ import argparse
6
+ from loguru import logger
7
+ import os
8
+
9
+ from predict import Model
10
+
11
+ from datetime import datetime
12
+ from scipy import signal
13
+ import plotly.graph_objects as go
14
+ import numpy as np
15
+ import io
16
+ from PIL import Image
17
+
18
+ import cv2
19
+ from PySide6.QtCore import Qt, QThread, Signal, Slot
20
+ from PySide6.QtGui import QImage, QPixmap
21
+ from PySide6.QtWidgets import (
22
+ QApplication,
23
+ QHBoxLayout,
24
+ QLabel,
25
+ QMainWindow,
26
+ QPushButton,
27
+ QSizePolicy,
28
+ QVBoxLayout,
29
+ QWidget,
30
+ )
31
+
32
+ # for telemed
33
+ import matplotlib.pyplot as plt
34
+ import ctypes
35
+ from ctypes import *
36
+
37
+ # 720p
38
+ video_w = 1280
39
+ video_h = 720
40
+
41
+
42
+ # Copy from detection.py from telemed sample code
43
+ class Telemed:
44
+ def __init__(self):
45
+         # start of code copied from the original main
46
+
47
+ # Setting ultrasound size
48
+ # w = 512
49
+ # h = 512
50
+ w = 640
51
+ h = 640
52
+
53
+ # Load dll
54
+ # usgfw2 = cdll.LoadLibrary('./usgfw2wrapper_C++_sources/usgfw2wrapper/x64/Release/usgfw2wrapper.dll')
55
+ usgfw2 = cdll.LoadLibrary("./usgfw2wrapper.dll")
56
+
57
+ # Ultrasound initialize
58
+ usgfw2.on_init()
59
+ ERR = usgfw2.init_ultrasound_usgfw2()
60
+
61
+ # Check probe
62
+ if ERR == 2:
63
+ logger.error("Main Usgfw2 library object not created")
64
+ usgfw2.Close_and_release()
65
+ sys.exit()
66
+
67
+ ERR = usgfw2.find_connected_probe()
68
+
69
+ if ERR != 101:
70
+ logger.error("Probe not detected")
71
+ usgfw2.Close_and_release()
72
+ sys.exit()
73
+
74
+ ERR = usgfw2.data_view_function()
75
+
76
+ if ERR < 0:
77
+ logger.error(
78
+ "Main ultrasound scanning object for selected probe not created"
79
+ )
80
+ sys.exit()
81
+
82
+ ERR = usgfw2.mixer_control_function(0, 0, w, h, 0, 0, 0)
83
+ if ERR < 0:
84
+ logger.error("B mixer control not returned")
85
+ sys.exit()
86
+
87
+ # Probe setting
88
+ res_X = ctypes.c_float(0.0)
89
+ res_Y = ctypes.c_float(0.0)
90
+ usgfw2.get_resolution(ctypes.pointer(res_X), ctypes.pointer(res_Y))
91
+
92
+ X_axis = np.zeros(shape=(w))
93
+ Y_axis = np.zeros(shape=(h))
94
+ if w % 2 == 0:
95
+ k = 0
96
+ for i in range(-w // 2, w // 2 + 1):
97
+ if i < 0:
98
+ j = i + 0.5
99
+ X_axis[k] = j * res_X.value
100
+ k = k + 1
101
+ else:
102
+ if i > 0:
103
+ j = i - 0.5
104
+ X_axis[k] = j * res_X.value
105
+ k = k + 1
106
+
107
+ else:
108
+ for i in range(-w // 2, w // 2):
109
+ X_axis[i + w / 2 + 1] = i * res_X.value
110
+
111
+ for i in range(0, h - 1):
112
+ Y_axis[i] = i * res_Y.value
113
+
114
+ old_resolution_x = res_X.value
115
+         old_resolution_y = res_Y.value
116
+
117
+ # Image setting
118
+ p_array = (ctypes.c_uint * w * h * 4)()
119
+
120
+ fig, ax = plt.subplots()
121
+ usgfw2.return_pixel_values(ctypes.pointer(p_array))
122
+ buffer_as_numpy_array = np.frombuffer(p_array, np.uint)
123
+ reshaped_array = np.reshape(buffer_as_numpy_array, (w, h, 4))
124
+
125
+ img = ax.imshow(
126
+ reshaped_array[:, :, 0:3],
127
+ cmap="gray",
128
+ vmin=0,
129
+ vmax=255,
130
+ origin="lower",
131
+ extent=[np.amin(X_axis), np.amax(X_axis), np.amax(Y_axis), np.amin(Y_axis)],
132
+ )
133
+
134
+         # start of code copied from the original __init__
135
+ self.w = w
136
+ self.h = h
137
+
138
+ (
139
+ self.usgfw2,
140
+ self.p_array,
141
+ self.res_X,
142
+ self.res_Y,
143
+ self.old_resolution_x,
144
+ self.old_resolution_y,
145
+ self.X_axis,
146
+ self.Y_axis,
147
+ self.img,
148
+ ) = (
149
+ usgfw2,
150
+ p_array,
151
+ res_X,
152
+ res_Y,
153
+ old_resolution_x,
154
+ old_resolution_y,
155
+ X_axis,
156
+ Y_axis,
157
+ img,
158
+ )
159
+
160
+ # return the image from telemed
161
+ def imaging(self):
162
+ self.usgfw2.return_pixel_values(ctypes.pointer(self.p_array))
163
+ buffer_as_numpy_array = np.frombuffer(self.p_array, np.uint)
164
+ reshaped_array = np.reshape(buffer_as_numpy_array, (self.w, self.h, 4))
165
+
166
+ self.usgfw2.get_resolution(
167
+ ctypes.pointer(self.res_X), ctypes.pointer(self.res_Y)
168
+ )
169
+ if (
170
+ self.res_X.value != self.old_resolution_x
171
+ or self.res_Y.value != self.old_resolution_y
172
+ ):
173
+ if self.w % 2 == 0:
174
+ k = 0
175
+ for i in range(-self.w // 2, self.w // 2 + 1):
176
+ if i < 0:
177
+ j = i + 0.5
178
+ self.X_axis[k] = j * self.res_X.value
179
+ k = k + 1
180
+ else:
181
+ if i > 0:
182
+ j = i - 0.5
183
+ self.X_axis[k] = j * self.res_X.value
184
+ k = k + 1
185
+ else:
186
+ for i in range(-self.w // 2, self.w // 2):
187
+ self.X_axis[i + self.w / 2 + 1] = i * self.res_X.value
188
+
189
+ for i in range(0, self.h - 1):
190
+ self.Y_axis[i] = i * self.res_Y.value
191
+
192
+ self.old_resolution_x = self.res_X.value
193
+             self.old_resolution_y = self.res_Y.value
194
+
195
+ self.img.set_data(reshaped_array[:, :, 0:3])
196
+ self.img.set_extent(
197
+ [
198
+ np.amin(self.X_axis),
199
+ np.amax(self.X_axis),
200
+ np.amax(self.Y_axis),
201
+ np.amin(self.Y_axis),
202
+ ]
203
+ )
204
+
205
+ # Transfer image format to cv2
206
+ img_array = np.asarray(self.img.get_array())
207
+         img_array = img_array[::-1, :, ::-1] # format same as plt image, RGB to BGR
208
+ return img_array
209
+
210
+
211
+ class Thread(QThread):
212
+ updateFrame = Signal(QImage)
213
+
214
+ def __init__(self, parent=None, args=None):
215
+ QThread.__init__(self, parent)
216
+ self.status = True
217
+ self.cap = True
218
+ self.args = args
219
+
220
+ # init telemed
221
+ if args.video is None:
222
+ self.telemed = Telemed()
223
+
224
+ # init model
225
+ is_async = (
226
+ True if self.args.jobs == "auto" or int(self.args.jobs) > 1 else False
227
+ )
228
+ self.model = Model(
229
+ model_path=self.args.model,
230
+ imgsz=self.args.img_size,
231
+ classes=self.args.classes,
232
+ device=self.args.device,
233
+ plot_mask=self.args.plot_mask,
234
+ conf_thres=self.args.conf_thres,
235
+ is_async=is_async,
236
+ n_jobs=self.args.jobs,
237
+ )
238
+
239
+ def get_stats_fig(self, aorta_widths, aorta_confs, fig_w, fig_h, ts):
240
+ title_font_size = 28
241
+ body_font_size = 24
242
+ img_quality = 100 * np.mean(aorta_confs)
243
+ avg_width = np.mean(aorta_widths)
244
+ max_width = np.max(aorta_widths)
245
+ suggestions = [
246
+ "N/A, within normal limit",
247
+ "Follow up in 5 years",
248
+ "Make an appointment as soon as possible",
249
+ ]
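+         # thresholds used below: mean width < 3 cm -> suggestions[0], 3-5 cm -> suggestions[1], >= 5 cm -> suggestions[2]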
250
+ s = None
251
+ if avg_width < 3:
252
+ s = suggestions[0]
253
+ elif avg_width < 5:
254
+ s = suggestions[1]
255
+ else:
256
+ s = suggestions[2]
257
+
258
+ # region smoothing: method 2, keep the peaks
259
+ # peaks = signal.find_peaks(aorta_widths, height=0.5, distance=40)
260
+ # new_y = []
261
+ # # smooth the values between the peaks
262
+ # start = 0
263
+ # end = peaks[0][0]
264
+ # new_y.extend(signal.savgol_filter(aorta_widths[start:end], end - start, 2))
265
+ # for i in range(len(peaks[0]) - 1):
266
+ # start = peaks[0][i] + 1
267
+ # end = peaks[0][i + 1]
268
+ # new_y.append(aorta_widths[peaks[0][i]]) # add peak value
269
+ # new_y.extend(
270
+ # signal.savgol_filter(
271
+ # aorta_widths[start:end],
272
+ # end - start, # window size used for filtering
273
+ # 2,
274
+ # )
275
+ # ) # order of fitted polynomial
276
+ # # add the last peak
277
+ # new_y.append(aorta_widths[peaks[0][-1]])
278
+ # start = peaks[0][-1] + 1
279
+ # end = len(aorta_widths)
280
+ # new_y.extend(signal.savgol_filter(aorta_widths[start:end], end - start, 2))
281
+ # endregion
282
+
283
+ # region smoothing: method 1, do not keep the peaks
284
+ window_size = 53
285
+ if len(aorta_widths) < window_size:
286
+ window_size = len(aorta_widths) - 1
287
+ new_y = signal.savgol_filter(aorta_widths, window_size, 3)
288
+ # endregion
289
+
290
+ x = np.arange(1, len(aorta_widths) + 1, dtype=int)
291
+
292
+ fig = go.Figure()
293
+ fig.add_trace(
294
+ go.Scatter(
295
+ x=x, y=aorta_widths, mode="lines", line=dict(color="royalblue", width=1)
296
+ )
297
+ )
298
+ fig.add_trace(
299
+ go.Scatter(
300
+ x=x,
301
+ y=new_y,
302
+ mode="lines",
303
+ marker=dict(
304
+ size=3,
305
+ color="mediumpurple",
306
+ ),
307
+ )
308
+ )
309
+ fig.update_layout(
310
+ autosize=False,
311
+ width=fig_w,
312
+ height=fig_h,
313
+ margin=dict(l=50, r=50, b=50, t=400, pad=4),
314
+ paper_bgcolor="LightSteelBlue",
315
+ showlegend=False,
316
+ )
317
+ fig.add_annotation(
318
+ text=f"max={max_width:.2f} cm",
319
+ x=np.argmax(aorta_widths),
320
+ y=np.max(aorta_widths),
321
+ xref="x",
322
+ yref="y",
323
+ showarrow=True,
324
+ font=dict(color="#ffffff"),
325
+ arrowhead=2,
326
+ arrowsize=1,
327
+ arrowwidth=2,
328
+ borderpad=4,
329
+ bgcolor="#ff7f0e",
330
+ opacity=0.8,
331
+ )
332
+ fig.add_annotation(
333
+ text=f"smoothed max={np.max(new_y):.2f} cm",
334
+ x=np.argmax(new_y),
335
+ y=np.max(new_y),
336
+ xref="x",
337
+ yref="y",
338
+ showarrow=True,
339
+ font=dict(color="#ffffff"),
340
+ arrowhead=2,
341
+ arrowsize=1,
342
+ arrowwidth=2,
343
+ ax=-100,
344
+ ay=-50,
345
+ borderpad=4,
346
+ bgcolor="#ff7f0e",
347
+ opacity=0.8,
348
+ )
349
+ fig.add_annotation(
350
+ text="<b>Report of Abdominal Aorta Examination</b>",
351
+ xref="paper",
352
+ yref="paper",
353
+ x=0.5,
354
+ y=2.3,
355
+ showarrow=False,
356
+ font=dict(size=title_font_size),
357
+ )
358
+ fig.add_annotation(
359
+ text=f"Image acquisition quality: {img_quality:.0f}%",
360
+ xref="paper",
361
+ yref="paper",
362
+ x=0,
363
+ y=2.0,
364
+ showarrow=False,
365
+ font=dict(size=body_font_size),
366
+ )
367
+ fig.add_annotation(
368
+ text=f"Aorta Maximal Width: {max_width:.2f} cm",
369
+ xref="paper",
370
+ yref="paper",
371
+ x=0,
372
+ y=1.8,
373
+ showarrow=False,
374
+ font=dict(size=body_font_size),
375
+ )
376
+ fig.add_annotation(
377
+ text=f"Aorta Maximal Width (Smoothed): {np.max(new_y):.2f} cm",
378
+ xref="paper",
379
+ yref="paper",
380
+ x=0,
381
+ y=1.6,
382
+ showarrow=False,
383
+ font=dict(size=body_font_size),
384
+ )
385
+ fig.add_annotation(
386
+ text=f"Average: {avg_width:.2f} cm",
387
+ xref="paper",
388
+ yref="paper",
389
+ x=0,
390
+ y=1.4,
391
+ showarrow=False,
392
+ font=dict(size=body_font_size),
393
+ )
394
+ fig.add_annotation(
395
+ text=f"Suggestion: {s}",
396
+ xref="paper",
397
+ yref="paper",
398
+ x=0,
399
+ y=1.2,
400
+ showarrow=False,
401
+ font=dict(size=body_font_size),
402
+ )
403
+ fig.add_annotation(
404
+ text=f"Generated at {ts}",
405
+ xref="paper",
406
+ yref="paper",
407
+ x=1,
408
+ y=1,
409
+ showarrow=False,
410
+ )
411
+ return fig
412
+
413
+ def run(self):
414
+ one_cm_in_pixels = 48 # hard-coded
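+         # NOTE: assumed pixel-to-cm scale of the demo video / probe view;
+         # recalibrate if the input resolution or imaging depth changes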
415
+ aorta_cm_thre1 = 3
416
+ aorta_cm_thre2 = 5
417
+ black = (0, 0, 0)
418
+ white = (255, 255, 255)
419
+ red = (0, 0, 255)
420
+ green = (0, 255, 0)
421
+
422
+ aorta_widths_stats = [0, 0, 0] # three ranges: <3, 3-5, >5
423
+ aorta_widths = []
424
+ aorta_confs = []
425
+
426
+ expected_fps = None
427
+ frame_count = None
428
+ frame_w = None
429
+ frame_h = None
430
+ if self.args.video:
431
+ self.cap = cv2.VideoCapture(self.args.video)
432
+ expected_fps = self.cap.get(cv2.CAP_PROP_FPS)
433
+ secs_per_frame = 1 / expected_fps
434
+ frame_w, frame_h = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(
435
+ self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
436
+ )
437
+ frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
438
+ logger.info(f"Video source FPS: {expected_fps}")
439
+             logger.info(f"Seconds per frame: {secs_per_frame}")
440
+ logger.info(f"Video source resolution (WxH): {frame_w}x{frame_h}")
441
+ logger.info(f"Video source frame count: {frame_count}")
442
+ assert frame_count > 0, "No frame found"
443
+
444
+ n_read_frames = 0
445
+ next_frame_to_infer = 0
446
+ next_frame_to_show = 0
447
+ n_repeat_failure = 0
448
+ is_last_failed = False
449
+ start_time = perf_counter()
450
+ while self.status:
451
+ frame = None
452
+
453
+ # avoid infinite loop
454
+ if n_repeat_failure > 30:
455
+ break
456
+
457
+ # inference
458
+ color_frame, others, results, xyxy, conf = None, None, None, None, None
459
+ if self.model.is_async:
460
+ results = self.model.get_result(next_frame_to_show)
461
+ if results:
462
+ color_frame, others = results
463
+ xyxy, conf, _ = others
464
+ next_frame_to_show += 1
465
+
466
+ if self.model.is_async and self.model.is_free_to_infer_async():
467
+ if self.args.video:
468
+ ret, frame = self.cap.read()
469
+
470
+ if not ret:
471
+ n_repeat_failure += 1 if is_last_failed else 0
472
+ is_last_failed = True
473
+ continue
474
+ else:
475
+ # read the frame from telemed
476
+ # TODO(martin): Check read failure
477
+ frame = self.telemed.imaging()
478
+
479
+ n_read_frames += 1
480
+ self.model.predict_async(frame, next_frame_to_infer)
481
+ next_frame_to_infer += 1
482
+ elif not self.model.is_async:
483
+ if self.args.video:
484
+ ret, frame = self.cap.read()
485
+ if not ret:
486
+ n_repeat_failure += 1 if is_last_failed else 0
487
+ is_last_failed = True
488
+ continue
489
+ else:
490
+ # read the frame from telemed
491
+ # TODO(martin): Check read failure
492
+ frame = self.telemed.imaging()
493
+
494
+ n_read_frames += 1
495
+                 results = self.model.predict(frame)
496
+                 if results is None:
497
+                     continue
498
+                 color_frame, others = results
499
+                 xyxy, conf, _ = others # bbox and confidence
500
+
501
+ is_last_failed = False
502
+
503
+ # check if aorta is within the ROI box, and draw the box
504
+ aorta_width_in_cm = 0
505
+ is_found = xyxy is not None
506
+ is_in_box = False
507
+ is_too_left, is_too_right = False, False
508
+ w, h = color_frame.shape[1], color_frame.shape[0]
509
+ box_w = int(w * 0.1)
510
+ box_h = int(h * 0.5)
511
+ box_top_left = (w // 2 - box_w // 2, h // 4)
512
+ box_bottom_right = (w // 2 + box_w // 2, h // 4 + box_h)
513
+ if xyxy is not None:
514
+ x1, y1, x2, y2 = xyxy
515
+
516
+ # check aorta width
517
+ aorta_width_in_cm = (x2 - x1) / one_cm_in_pixels
518
+ aorta_widths.append(aorta_width_in_cm)
519
+ aorta_confs.append(conf)
520
+ if aorta_width_in_cm < aorta_cm_thre1:
521
+ aorta_widths_stats[0] += 1
522
+ elif aorta_width_in_cm < aorta_cm_thre2:
523
+ aorta_widths_stats[1] += 1
524
+ else:
525
+ aorta_widths_stats[2] += 1
526
+
527
+ # check whether aorta is in the box
528
+ if (
529
+ x1 > box_top_left[0]
530
+ and x2 < box_bottom_right[0]
531
+ and y1 > box_top_left[1]
532
+ and y2 < box_bottom_right[1]
533
+ ):
534
+ is_in_box = True
535
+ is_too_right = x2 > box_bottom_right[0]
536
+ is_too_left = x1 < box_top_left[0]
537
+
538
+ # plot ROI box with color status
539
+ box_color = green if is_in_box else red
540
+ color_frame = cv2.rectangle(
541
+ color_frame, box_top_left, box_bottom_right, box_color, 2
542
+ )
543
+ assert not (
544
+ is_too_left and is_too_right
545
+ ), "Cannot be both too left and too right"
546
+ if is_too_left:
547
+ start_p = (box_top_left[0], int(h * 0.9))
548
+ end_p = (box_bottom_right[0], int(h * 0.9))
549
+ cv2.arrowedLine(color_frame, start_p, end_p, red, 3)
550
+ if is_too_right:
551
+ start_p = (box_bottom_right[0], int(h * 0.9))
552
+ end_p = (box_top_left[0], int(h * 0.9))
553
+ cv2.arrowedLine(color_frame, start_p, end_p, red, 3)
554
+ if is_in_box:
555
+ cv2.putText(
556
+ color_frame,
557
+ "GOOD",
558
+ (box_top_left[0], int(h * 0.9)),
559
+ cv2.FONT_HERSHEY_SIMPLEX,
560
+ 1,
561
+ green,
562
+ 3,
563
+ )
564
+
565
+ # plot aorta width
566
+ text = (
567
+ f"Aorta width: {aorta_width_in_cm:.2f} cm"
568
+ if is_found
569
+ else "Aorta width: N/A"
570
+ )
571
+ cv2.putText(
572
+ color_frame, text, (50, 90), cv2.FONT_HERSHEY_SIMPLEX, 1, white, 3
573
+ )
574
+
575
+ # region FPS
576
+ fps = None
577
+ if n_read_frames > 0:
578
+ fps = n_read_frames / (perf_counter() - start_time)
579
+
580
+ # Slow down the loop if FPS is too high
581
+ if self.args.sync:
582
+ while fps > expected_fps:
583
+ time.sleep(0.001)
584
+ fps = n_read_frames / (perf_counter() - start_time)
585
+
586
+ cv2.putText(
587
+ color_frame,
588
+ f"FPS: {fps:.2f}",
589
+ (50, 30),
590
+ cv2.FONT_HERSHEY_SIMPLEX,
591
+ 1,
592
+ white,
593
+ 3,
594
+ )
595
+ # endregion
596
+
597
+ # Creating and scaling QImage
598
+ h, w, ch = color_frame.shape
599
+ img = QImage(color_frame.data, w, h, ch * w, QImage.Format_BGR888)
600
+ scaled_img = img.scaled(video_w, video_h, Qt.KeepAspectRatio)
601
+
602
+ # Emit signal
603
+ self.updateFrame.emit(scaled_img)
604
+
605
+ if self.args.video:
606
+ progress = 100 * n_read_frames / frame_count
607
+ fps_msg = f", FPS: {fps:.2f}" if fps is not None else ""
608
+ print(
609
+ f"Processed {n_read_frames}/{frame_count} ({progress:.2f}%) frames"
610
+ + fps_msg,
611
+ end="\r" if n_read_frames < frame_count else os.linesep,
612
+ )
613
+ if n_read_frames >= frame_count:
614
+ logger.info("Finished processing video")
615
+ break
616
+ if self.args.video:
617
+ self.cap.release()
618
+
619
+ if not self.status:
620
+ logger.info("Stopped by user")
621
+ return
622
+
623
+ # draw a black image with frame width & height
624
+ # with some text in center indicating generating report
625
+ # it's just a dummy step to make demo more real
626
+ im = np.zeros((frame_h, frame_w, 3), np.uint8)
627
+ cv2.putText(
628
+ im,
629
+ "Generating report for you...",
630
+ (frame_w // 3, frame_h // 2),
631
+ cv2.FONT_HERSHEY_SIMPLEX,
632
+ 1,
633
+ white,
634
+ 3,
635
+ )
636
+ img = QImage(im.data, frame_w, frame_h, ch * w, QImage.Format_BGR888)
637
+ scaled_img = img.scaled(video_w, video_h, Qt.KeepAspectRatio)
638
+ self.updateFrame.emit(scaled_img)
639
+ time.sleep(3)
640
+
641
+ # plot aorta width tracing line chart
642
+ now_t = datetime.now()
643
+ ts1 = now_t.strftime("%Y%m%d_%H%M%S")
644
+ ts2 = now_t.strftime("%Y/%m/%d %I:%M:%S")
645
+ Path("runs").mkdir(parents=True, exist_ok=True)
646
+ # np.save("runs/aorta_widths.npy", aorta_widths)
647
+ fig_out_p = f"runs/aorta_report_{ts1}.jpeg"
648
+ fig = self.get_stats_fig(aorta_widths, aorta_confs, video_w, video_h, ts2)
649
+
650
+ # This may hang under Windows: https://github.com/plotly/Kaleido/issues/110
651
+ # The workaround is to install older kaleido version (see requirements.txt)
652
+ fig.write_image(fig_out_p)
653
+
654
+ logger.info(f"Saved aorta report: {fig_out_p}")
655
+ img_bytes = fig.to_image(format="jpg", width=video_w, height=video_h)
656
+ line_chart = np.array(Image.open(io.BytesIO(img_bytes)))
657
+ line_chart = cv2.cvtColor(line_chart, cv2.COLOR_RGB2BGR)
658
+ h, w, ch = line_chart.shape
659
+ img = QImage(line_chart.data, video_w, video_h, ch * w, QImage.Format_BGR888)
660
+ scaled_img = img.scaled(w, h, Qt.KeepAspectRatio)
661
+ # Emit signal
662
+ self.updateFrame.emit(scaled_img)
663
+ time.sleep(5)
664
+
665
+ # keep report open until user closes the window
666
+ while self.status and not self.args.exit_on_end:
667
+ time.sleep(0.1)
668
+
669
+
670
+ class Window(QMainWindow):
671
+ def __init__(self, args=None):
672
+ super().__init__()
673
+ # Title and dimensions
674
+ self.setWindowTitle("Demo")
675
+ self.setGeometry(0, 0, 800, 500)
676
+
677
+ # Create a label for the display camera
678
+ self.label = QLabel(self)
679
+ # self.label.setFixedSize(self.width(), self.height())
680
+ self.label.setFixedSize(video_w, video_h)
681
+
682
+ # Thread in charge of updating the image
683
+ self.th = Thread(self, args)
684
+ self.th.finished.connect(self.close)
685
+ self.th.updateFrame.connect(self.setImage)
686
+
687
+ # Buttons layout
688
+ buttons_layout = QHBoxLayout()
689
+ self.button1 = QPushButton("Start")
690
+ self.button2 = QPushButton("Stop/Close")
691
+ self.button1.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding)
692
+ self.button2.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding)
693
+ buttons_layout.addWidget(self.button2)
694
+ buttons_layout.addWidget(self.button1)
695
+
696
+ right_layout = QHBoxLayout()
697
+ # right_layout.addWidget(self.group_model, 1)
698
+ right_layout.addLayout(buttons_layout, 1)
699
+
700
+ # Main layout
701
+ layout = QVBoxLayout()
702
+ layout.addWidget(self.label)
703
+ layout.addLayout(right_layout)
704
+
705
+ # Central widget
706
+ widget = QWidget(self)
707
+ widget.setLayout(layout)
708
+ self.setCentralWidget(widget)
709
+
710
+ # Connections
711
+ self.button1.clicked.connect(self.start)
712
+ self.button2.clicked.connect(self.kill_thread)
713
+ self.button2.setEnabled(False)
714
+
715
+ if args is not None and args.start_on_open:
716
+ # start thread
717
+ self.start()
718
+
719
+ @Slot()
720
+ def kill_thread(self):
721
+ logger.info("Finishing...")
722
+ self.th.status = False
723
+ time.sleep(1)
724
+ # Give time for the thread to finish
725
+ self.button2.setEnabled(False)
726
+ self.button1.setEnabled(True)
727
+ cv2.destroyAllWindows()
728
+ self.th.exit()
729
+ # Give time for the thread to finish
730
+ time.sleep(1)
731
+
732
+ @Slot()
733
+ def start(self):
734
+ logger.info("Starting...")
735
+ self.button2.setEnabled(True)
736
+ self.button1.setEnabled(False)
737
+ self.th.start()
738
+ logger.info("Thread started")
739
+
740
+ @Slot(QImage)
741
+ def setImage(self, image):
742
+ self.label.setPixmap(QPixmap.fromImage(image))
743
+
744
+
745
+ if __name__ == "__main__":
746
+ # get user inputs using argparse
747
+ parser = argparse.ArgumentParser()
748
+ parser.add_argument(
749
+ "--video",
750
+ type=str,
751
+ default=None,
752
+ help="path to video file, if None (default) would read from telemed",
753
+ )
754
+ parser.add_argument(
755
+ "--model",
756
+ type=str,
757
+ default="best_openvino_model/best.xml",
758
+ help="path to model file",
759
+ )
760
+ parser.add_argument("--img-size", type=int, default=640, help="image size")
761
+ parser.add_argument(
762
+ "--classes", nargs="+", type=int, default=[0], help="filter by class"
763
+ )
764
+ parser.add_argument("--device", type=str, default="CPU", help="device to use")
765
+ parser.add_argument("--sync", action="store_true", help="sync video FPS")
766
+ parser.add_argument("--plot-mask", action="store_true", help="plot mask")
767
+ parser.add_argument("--conf-thres", type=float, default=0.25, help="conf thresh")
768
+ parser.add_argument("--jobs", type=str, default=1, help="num of jobs, async if > 1")
769
+ parser.add_argument("--start-on-open", action="store_true", help="start on open")
770
+ parser.add_argument("--exit-on-end", action="store_true", help="exit if video ends")
771
+ args = parser.parse_args()
772
+ assert (
773
+ args.jobs == "auto" or int(args.jobs) > 0
774
+ ), f"--jobs must be > 0 or auto, got {args.jobs}"
775
+ if args.video:
776
+ assert Path(args.video).exists(), f"Video file {args.video} not found"
777
+ assert Path(args.model).exists(), f"Model file {args.model} not found"
778
+ app = QApplication()
779
+ w = Window(args)
780
+ w.show()
781
+ sys.exit(app.exec())
demo_headless.sh ADDED
@@ -0,0 +1,27 @@
1
+ #!/bin/bash
2
+
3
+ VIRTUAL_DISPLAY_NUM=99
4
+
5
+ OUTPUT_VIDEO="runs/demo_recording_$(date +"%Y-%m-%d_%H-%M-%S").mp4"
6
+
7
+ # start xvfb server
8
+ Xvfb :$VIRTUAL_DISPLAY_NUM -screen 0 1280x720x24 > /dev/null & XVFB_PID=$!
9
+
10
+ # start recording
11
+ ffmpeg -f x11grab -draw_mouse 0 -video_size 1280x720 \
12
+ -i :$VIRTUAL_DISPLAY_NUM \
13
+ -codec:v libx264 -r 25 $OUTPUT_VIDEO \
14
+ > /dev/null 2>&1 < /dev/null & FFMPEG_PID=$!
15
+
16
+ # start the demo program
17
+ DISPLAY=:$VIRTUAL_DISPLAY_NUM QT_QPA_PLATFORM=xcb \
18
+ python demo.py "$@" --start-on-open --exit-on-end
19
+
20
+ # kill the recording
21
+ kill $FFMPEG_PID
22
+
23
+ # kill xvfb server
24
+ kill $XVFB_PID
25
+
26
+ # success msg
27
+ echo -e "Recording saved: $OUTPUT_VIDEO"
eval.py ADDED
@@ -0,0 +1,397 @@
1
+ import typer
2
+ from typing import Optional
3
+ from pathlib import Path
4
+ from loguru import logger
5
+ import cv2
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+ import pandas as pd
9
+ import shutil
10
+ from datetime import datetime
11
+ import matplotlib
12
+ import os
13
+
14
+ matplotlib.use("Agg") # use non-interactive backend
15
+ import matplotlib.pyplot as plt
16
+
17
+ from predict import Model
18
+
19
+ app = typer.Typer()
20
+
21
+
22
+ @app.command(help="Export videos to images (to a dir per video)")
23
+ def export_videos_to_images(
24
+ input_dir: Path = typer.Argument(..., help="Input directory"),
25
+ output_dir: Path = typer.Argument(..., help="Output directory"),
26
+ ext: str = typer.Option("avi", help="Video Extension"),
27
+ path_filter: Optional[str] = typer.Option(None, help="input path filter"),
28
+ patient_prefix: Optional[bool] = typer.Option(
29
+ True, help="use patient info as output dir prefix"
30
+ ),
31
+ copy_extent: Optional[bool] = typer.Option(
32
+ True, help="copy extent files to output dir"
33
+ ),
34
+ ):
35
+ # log all the arguments passed in
36
+ logger.info(f"Function called with arguments: {locals()}")
37
+
38
+ # find all video files in input_dir
39
+ input_dir = Path(input_dir)
40
+ output_dir = Path(output_dir)
41
+ output_dir.mkdir(parents=True, exist_ok=True)
42
+ video_files = list(input_dir.glob(f"**/*.{ext.lower()}"))
43
+ video_files.extend(list(input_dir.glob(f"**/*.{ext.upper()}")))
44
+ logger.info(f"# of avi videos found: {len(video_files)}")
45
+ if path_filter is not None:
46
+ logger.info(f"Filtering videos with {path_filter}")
47
+ video_files = [x for x in video_files if path_filter in str(x)]
48
+ logger.info(f"# of avi videos found after filtering: {len(video_files)}")
49
+
50
+ video_files.sort(key=lambda x: x.name) # sort by name ascending
51
+ # log each video path after filtering, one per line
52
+ logger.info(f"{os.linesep}" + f"{os.linesep}".join([str(x) for x in video_files]))
53
+
54
+ # check that all the extent files exist
55
+ # the extent (.csv) should be in the same directory as the video
56
+ # the video filename would start with video_
57
+ # the extent filename would start with extents_
58
+ if copy_extent:
59
+ all_exist = True
60
+ for video_path in video_files:
61
+ extent_filename = video_path.stem.replace("video_", "extents_")
62
+ extent_path = video_path.parent / f"{extent_filename}.csv"
63
+ if not extent_path.exists():
64
+ logger.error(f"Extent file {extent_path} does not exist")
65
+ all_exist = False
66
+ if not all_exist:
67
+ logger.error("Extent files do not exist for all videos")
68
+ return
69
+
70
+ for video_path in video_files:
71
+ # copy extent file to output dir
72
+ if copy_extent:
73
+ extent_filename = video_path.stem.replace("video_", "extents_")
74
+ extent_path = video_path.parent / f"{extent_filename}.csv"
75
+ shutil.copy(extent_path, output_dir)
76
+
77
+ # Dir structure: Patient_Info / [PATIENT_ID] / [DATE] / video / xxx.avi
78
+ patient_id = (
79
+ video_path.parent.parent.parent.name
80
+ ) # WARNING: Hard-coded based on dir structure
81
+
82
+ video_name = video_path.stem
83
+ logger.info(f"Processing video {video_name} of patient {patient_id}")
84
+
85
+ # create subdirectory for each video
86
+ sub_dir = output_dir / (
87
+ f"{patient_id}-{video_name}" if patient_prefix else video_name
88
+ )
89
+ sub_dir.mkdir(parents=True, exist_ok=True)
90
+
91
+ # read video and export frames
92
+ cap = cv2.VideoCapture(str(video_path))
93
+ frame_count = 0
94
+ while cap.isOpened():
95
+ ret, frame = cap.read()
96
+ if ret:
97
+ # padding frame_count with zeros
98
+ cv2.imwrite(str(sub_dir / f"{frame_count:03}.jpg"), frame)
99
+ frame_count += 1
100
+ else:
101
+ break
102
+
103
+
104
+ @app.command(help="Evaluate model on a directory of images")
105
+ def eval(
106
+ input_dir: Path = typer.Argument(..., help="Input directory"),
107
+ input_model: Path = typer.Argument(..., help="Input model"),
108
+ imgsz: int = typer.Option(640, help="Image size"),
109
+ class_id: int = typer.Option(0, help="Class id to filter"),
110
+ conf_thresh: float = typer.Option(0.5, help="Confidence threshold"),
111
+ video_ext: str = typer.Option("avi", help="Video Extension"),
112
+ out_dir: Path = typer.Option("runs", help="Output directory"),
113
+ gt_csv_path: Path = typer.Option(
114
+ "results_20230822_aorta_identified_added_by_Ray.csv",
115
+ help="Ground truth csv path",
116
+ ),
117
+ no_extent: Optional[bool] = typer.Option(True, help="no extent file"),
118
+ write_viz: Optional[bool] = typer.Option(False, help="write viz images"),
119
+ gt_column_name: str = typer.Option("aorta_identified", help="Ground truth column"),
120
+ ):
121
+ # check inputs are valid
122
+ assert input_dir.exists(), f"Input directory {input_dir} does not exist"
123
+ assert input_model.exists(), f"Input model {input_model} does not exist"
124
+ assert gt_csv_path.exists(), f"Ground truth csv {gt_csv_path} does not exist"
125
+
126
+ # load model
127
+ model = Model(
128
+ model_path=str(input_model),
129
+ imgsz=imgsz,
130
+ classes=[class_id], # filter by class id, only aorta
131
+ device="CPU",
132
+ plot_mask=True,
133
+ conf_thres=conf_thresh,
134
+ is_async=False,
135
+ n_jobs=1,
136
+ )
137
+
138
+ # setup output directory
139
+ out_dir = Path(out_dir)
140
+ # create a sub output directory of current date and time
141
+ start_t = datetime.now()
142
+ start_timestamp = start_t.strftime("%Y_%m_%d_%H_%M_%S")
143
+ out_dir = out_dir / f"max_aorta_result-{start_timestamp}"
144
+ out_dir.mkdir(parents=True, exist_ok=True)
145
+
146
+ # log to file
147
+ logger.add(str(out_dir.absolute()) + "/eval_{time}.log")
148
+
149
+ out_csv_p = out_dir / "results.csv"
150
+ out_trace_csv_p = out_dir / "trace.csv"
151
+ logger.info(f"Output directory: {out_dir}")
152
+
153
+ # find all directories in input_dir
154
+ input_dir = Path(input_dir)
155
+ sub_dirs = [x for x in input_dir.iterdir() if x.is_dir()]
156
+ sub_dirs.sort(key=lambda x: x.name) # sort sub_dirs by name ascending
157
+ logger.info(f"# of subdirectories found: {len(sub_dirs)}")
158
+ num_sub_dirs = len(sub_dirs)
159
+ has_patient_prefix = not sub_dirs[0].name.startswith("video")
160
+
161
+ # setup csv headers
162
+ trace_headers = ["video", "image_idx", "aorta_pixels", "aorta_mm", "conf"]
163
+ headers = ["video", "max_aorta_pixels", "max_aorta_mm", "max_image_idx", "conf"]
164
+ if has_patient_prefix:
165
+ headers.insert(0, "patient_info")
166
+
167
+ # loop through each subdirectory of images
168
+ for idx, sub_dir in enumerate(sub_dirs):
169
+ max_aorta_w = -1 # max aorta width in pixels
170
+ max_aorta_w_mm = -1 # max aorta width in mm
171
+ max_aorta_viz = None
172
+ max_aorta_im_path = None
173
+ max_center_x, max_center_y = -1, -1
174
+ max_conf = None
175
+ max_im_n = ""
176
+
177
+ # read the extent file of the images
178
+ # the extent file should be in the same directory as the video
179
+ video_filename = (
180
+ sub_dir.name
181
+ if not has_patient_prefix
182
+ else "-".join(sub_dir.name.split("-")[1:])
183
+ )
184
+ extent_filename = video_filename.replace("video_", "extents_")
185
+ extent_file = sub_dir.parent / f"{extent_filename}.csv"
186
+ extents = None
187
+ if not no_extent:
188
+ assert extent_file.exists(), f"Extent file {extent_file} does not exist"
189
+ extents = pd.read_csv(extent_file).to_dict("records")
190
+
191
+ logger.info(f"Processing subdir {sub_dir.name} ({idx+1}/{num_sub_dirs})")
192
+ # find all images in sub_dir
193
+ images = list(sub_dir.glob("*.jpg"))
194
+ # Sort the list of images in ascending order by name
195
+ images.sort(key=lambda img: img.name)
196
+ logger.info(f"\t# of images found: {len(images)}")
197
+
198
+ # create a viz output directory for each sub_dir
199
+ out_sub_viz_dir = out_dir / sub_dir.name
200
+ Path(out_sub_viz_dir).mkdir(parents=True, exist_ok=True)
201
+
202
+ for im_idx, image_path in enumerate(tqdm(images)):
203
+ # read image
204
+ cv_frame = cv2.imread(str(image_path))
205
+ cv_width = cv_frame.shape[1]
206
+
207
+ # inference
208
+ viz_frame, results = model.predict(cv_frame)
209
+ bbox_xyxy = results[0]
210
+ conf = results[1]
211
+ masks = results[2]
212
+
213
+ # output viz image if the flag is set
214
+ if write_viz:
215
+ cv2.imwrite(
216
+ str(out_sub_viz_dir / f"{image_path.stem}_viz.jpg"),
217
+ viz_frame,
218
+ )
219
+
220
+ trace_row = [
221
+ f"{sub_dir.name}.{video_ext}",
222
+ image_path.stem,
223
+ -1,
224
+ -1,
225
+ conf,
226
+ ]
227
+
228
+ if masks is not None or bbox_xyxy is not None:
229
+ # method 1: find the largest contour
230
+ # find min enclosing circle of mask
231
+ # mask = (masks * 255).astype(np.uint8)
232
+ # contours, _ = cv2.findContours(
233
+ # mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
234
+ # )
235
+ # largest_contour = max(contours, key=cv2.contourArea)
236
+ # (center_x, center_y), radius = cv2.minEnclosingCircle(largest_contour)
237
+ # aorta_width = radius * 2
238
+
239
+ # method 2: use the height of the bbox as a measure of aorta width
240
+ # because we observed that the width of the bbox is too large
241
+ aorta_width = bbox_xyxy[3] - bbox_xyxy[1]
242
+
243
+ # get physical unit
244
+ w_mm_left, w_mm_right, w_mm_per_pixel = None, None, None
245
+ if not no_extent:
246
+ w_mm_left = extents[im_idx]["Width-Left(mm)"]
247
+ w_mm_right = extents[im_idx]["Width-Right(mm)"]
248
+ assert w_mm_right > 0 and w_mm_left < 0
249
+ w_mm_per_pixel = (w_mm_right - w_mm_left) / cv_width
250
+
251
+ # update trace when aorta is found
252
+ trace_row[2] = aorta_width
253
+ trace_row[3] = aorta_width * w_mm_per_pixel if not no_extent else None
254
+
255
+ # output viz image when aorta is found
256
+ cv2.imwrite(
257
+ str(out_sub_viz_dir / f"{image_path.stem}_viz.jpg"),
258
+ viz_frame,
259
+ )
260
+ # copy the raw image to the output directory
261
+ shutil.copy(image_path, out_sub_viz_dir)
262
+
263
+ if aorta_width > max_aorta_w:
264
+ max_aorta_w = aorta_width
265
+ max_aorta_viz = viz_frame.copy()
266
+ max_aorta_im_path = image_path
267
+
268
+ # Note: only need to calculate the center if using method 1
269
+ # max_center_x = center_x
270
+ # max_center_y = center_y
271
+
272
+ max_im_n = image_path.stem
273
+ max_conf = conf
274
+ logger.info(
275
+ f"\tNew max aorta (pixels): {max_aorta_w:.2f}, conf: {max_conf:.2f}"
276
+ )
277
+
278
+ # convert pixels to mm
279
+ max_aorta_w_mm = (
280
+ max_aorta_w * w_mm_per_pixel if not no_extent else None
281
+ )
282
+
283
+ # save trace to csv
284
+ df = pd.DataFrame([trace_row], columns=trace_headers)
285
+ df.to_csv(
286
+ out_trace_csv_p,
287
+ mode="a",
288
+ header=not out_trace_csv_p.exists(),
289
+ index=False,
290
+ float_format="%.3f",
291
+ )
292
+
293
+ if max_aorta_w > 0:
294
+ logger.info(f"\tMax aorta (pixels): {max_aorta_w:.2f}")
295
+ # copy the raw image to the output directory
296
+ out_raw_p = out_dir / f"raw_{sub_dir.name}_{max_im_n}.jpg"
297
+ shutil.copy(max_aorta_im_path, out_raw_p)
298
+
299
+ # method 1 viz: draw enclosing circle on max_aorta_viz
300
+ # plot circle on max_aorta_viz
301
+ # cv2.circle(
302
+ # max_aorta_viz,
303
+ # (int(max_center_x), int(max_center_y)),
304
+ # int(max_aorta_w / 2),
305
+ # (0, 255, 0),
306
+ # 2,
307
+ # )
308
+
309
+ # region Save the image with extent
310
+ # convert the BGR image to RGB image
311
+ out_viz_p = out_dir / f"viz_{sub_dir.name}_{max_im_n}.jpg"
312
+ max_aorta_viz_rgb = cv2.cvtColor(max_aorta_viz, cv2.COLOR_BGR2RGB)
313
+ # Use matplotlib to save the image
314
+ # Get the size of the image in inches
315
+ dpi = plt.rcParams["figure.dpi"] # Get the default dpi value
316
+ figsize = (
317
+ max_aorta_viz_rgb.shape[1] / dpi,
318
+ max_aorta_viz_rgb.shape[0] / dpi,
319
+ ) # width, height
320
+ # Create a new figure with the same aspect ratio as the image
321
+ fig = plt.figure(figsize=figsize)
322
+ if not no_extent:
323
+ # specify the extent of the image in the form [xmin, xmax, ymin, ymax]
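+ # Note: im_idx here is the index of the last frame processed in the loop above;
+ # this assumes the extent values are constant across all frames of a video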
324
+ extent = [
325
+ extents[im_idx]["Width-Left(mm)"],
326
+ extents[im_idx]["Width-Right(mm)"],
327
+ extents[im_idx]["Depth-Bottom(mm)"],
328
+ extents[im_idx]["Depth-Top(mm)"],
329
+ ]
330
+ plt.imshow(max_aorta_viz_rgb, extent=extent)
331
+ plt.xlabel("Width [mm]")
332
+ plt.ylabel("Depth [mm]")
333
+ else:
334
+ plt.imshow(max_aorta_viz_rgb)
335
+ plt.savefig(str(out_viz_p))
336
+ plt.close(fig)
337
+ # cv2.imwrite(str(out_viz_p), max_aorta_viz)
338
+ # endregion
339
+ else:
340
+ logger.warning(f"\tNo aorta found in {sub_dir.name}")
341
+ patient_info = sub_dir.name.split("-")[0] if has_patient_prefix else ""
342
+ row = [
343
+ f"{sub_dir.name}.{video_ext}",
344
+ max_aorta_w,
345
+ max_aorta_w_mm,
346
+ max_im_n,
347
+ max_conf,
348
+ ]
349
+ if has_patient_prefix:
350
+ row.insert(0, patient_info)
351
+ # remove patient info from sub_dir name
352
+ video_name = "-".join(sub_dir.name.split("-")[1:]) + f".{video_ext}"
353
+ row[1] = video_name
354
+
355
+ # export results to csv
356
+ # If file does not exist, this will create it, otherwise it will append to the file
357
+ df = pd.DataFrame([row], columns=headers)
358
+ df.to_csv(
359
+ out_csv_p,
360
+ mode="a",
361
+ header=not out_csv_p.exists(),
362
+ index=False,
363
+ float_format="%.3f",
364
+ )
365
+
366
+ # join the results with ground truth to add the ground truth column
367
+ # df_results = pd.read_csv(out_csv_p)
368
+ # df_gt = pd.read_csv(gt_csv_path)[["video", gt_column_name]] # id & gt columns
369
+ # df_gt_first = df_gt.drop_duplicates(subset="video", keep="first") # avoid new rows
370
+ # df_merged = pd.merge(df_results, df_gt_first, on="video", how="left")
371
+ # df_merged.to_csv(out_csv_p, header=True, index=False, float_format="%.3f")
372
+
373
+ # # show stats
374
+ # value_counts_with_nan = df_merged[gt_column_name].value_counts(dropna=False)
375
+ # total = len(df_merged)
376
+ # percentage = (value_counts_with_nan / total) * 100
377
+ # # Combine value counts and percentages into a DataFrame for better visualization
378
+ # stats = pd.DataFrame({"Count": value_counts_with_nan, "Percentage": percentage})
379
+ # logger.info(stats)
380
+
381
+ logger.info(f"Done! Results written to {out_csv_p}")
382
+
383
+
384
+ @app.command(help="Copy source images to viz result folder")
385
+ def copy_srcimg_to_vizdir(
386
+ src_img_dir: Path = typer.Argument(..., help="Source Images root directory"),
387
+ out_viz_dir: Path = typer.Argument(..., help="Target viz directory"),
388
+ ):
389
+ vizs = list(Path(out_viz_dir).glob("**/*.jpg"))
390
+ for viz in vizs:
391
+ splits = viz.stem.split("_")
392
+ ori_img = Path(src_img_dir) / splits[1] / f"{splits[2]}.jpg"
393
+ shutil.copy(ori_img, Path(out_viz_dir) / f"{viz.stem}_src.jpg")
394
+
395
+
396
+ if __name__ == "__main__":
397
+ app()
plots.py ADDED
@@ -0,0 +1,303 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import cv2
4
+ from PIL import Image, ImageDraw  # ImageDraw is required by the PIL branch of Annotator
5
+ import numpy as np
6
+
7
+
8
+ class Colors:
9
+ def __init__(self):
10
+ # hexs = matplotlib.colors.TABLEAU_COLORS.values()
11
+ hexs = (
12
+ "00FF00", # aorta class 0
13
+ "FF3838",
14
+ "FF701F",
15
+ "FFB21D",
16
+ "CFD231",
17
+ "48F90A",
18
+ "92CC17",
19
+ "3DDB86",
20
+ "1A9334",
21
+ "00D4BB",
22
+ "2C99A8",
23
+ "00C2FF",
24
+ "344593",
25
+ "6473FF",
26
+ "0018EC",
27
+ "8438FF",
28
+ "520085",
29
+ "CB38FF",
30
+ "FF95C8",
31
+ "FF37C7",
32
+ )
33
+ self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
34
+ self.n = len(self.palette)
35
+
36
+ def __call__(self, i, bgr=False):
37
+ c = self.palette[int(i) % self.n]
38
+ return (c[2], c[1], c[0]) if bgr else c
39
+
40
+ @staticmethod
41
+ def hex2rgb(h): # rgb order (PIL)
42
+ return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
43
+
44
+
45
+ colors = Colors() # create instance for 'from utils.plots import colors'
46
+
47
+
48
+ def is_ascii(s=""):
49
+ # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
50
+ s = str(s) # convert list, tuple, None, etc. to str
51
+ return len(s.encode().decode("ascii", "ignore")) == len(s)
52
+
53
+
54
+ def clip_boxes(boxes, shape):
55
+ # Clip boxes (xyxy) to image shape (height, width)
56
+ if isinstance(boxes, torch.Tensor): # faster individually
57
+ boxes[:, 0].clamp_(0, shape[1]) # x1
58
+ boxes[:, 1].clamp_(0, shape[0]) # y1
59
+ boxes[:, 2].clamp_(0, shape[1]) # x2
60
+ boxes[:, 3].clamp_(0, shape[0]) # y2
61
+ else: # np.array (faster grouped)
62
+ boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
63
+ boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
64
+
65
+
66
+ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
67
+ # Rescale boxes (xyxy) from img1_shape to img0_shape
68
+ if ratio_pad is None: # calculate from img0_shape
69
+ gain = min(
70
+ img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]
71
+ ) # gain = old / new
72
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (
73
+ img1_shape[0] - img0_shape[0] * gain
74
+ ) / 2 # wh padding
75
+ else:
76
+ gain = ratio_pad[0][0]
77
+ pad = ratio_pad[1]
78
+
79
+ boxes[:, [0, 2]] -= pad[0] # x padding
80
+ boxes[:, [1, 3]] -= pad[1] # y padding
81
+ boxes[:, :4] /= gain
82
+ clip_boxes(boxes, img0_shape)
83
+ return boxes
84
+
85
+
86
+ def crop_mask(masks, boxes):
87
+ """
88
+ "Crop" predicted masks by zeroing out everything not in the predicted bbox.
89
+ Vectorized by Chong (thanks Chong).
90
+ Args:
91
+ - masks should be a size [h, w, n] tensor of masks
92
+ - boxes should be a size [n, 4] tensor of bbox coords in relative point form
93
+ """
94
+
95
+ n, h, w = masks.shape
96
+ x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(1,1,n)
97
+ r = torch.arange(w, device=masks.device, dtype=x1.dtype)[
98
+ None, None, :
99
+ ] # rows shape(1,w,1)
100
+ c = torch.arange(h, device=masks.device, dtype=x1.dtype)[
101
+ None, :, None
102
+ ] # cols shape(h,1,1)
103
+
104
+ return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
105
+
106
+
107
+ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
108
+ """
109
+ Crop before upsample.
110
+ proto_out: [mask_dim, mask_h, mask_w]
111
+ out_masks: [n, mask_dim], n is number of masks after nms
112
+ bboxes: [n, 4], n is number of masks after nms
113
+ shape:input_image_size, (h, w)
114
+ return: h, w, n
115
+ """
116
+
117
+ c, mh, mw = protos.shape # CHW
118
+ ih, iw = shape
119
+ masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) # CHW
120
+
121
+ downsampled_bboxes = bboxes.clone()
122
+ downsampled_bboxes[:, 0] *= mw / iw
123
+ downsampled_bboxes[:, 2] *= mw / iw
124
+ downsampled_bboxes[:, 3] *= mh / ih
125
+ downsampled_bboxes[:, 1] *= mh / ih
126
+
127
+ masks = crop_mask(masks, downsampled_bboxes) # CHW
128
+ if upsample:
129
+ masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[
130
+ 0
131
+ ] # CHW
132
+ return masks.gt_(0.5)
133
+
134
+
135
+ def scale_image(im1_shape, masks, im0_shape, ratio_pad=None):
136
+ """
137
+ img1_shape: model input shape, [h, w]
138
+ img0_shape: origin pic shape, [h, w, 3]
139
+ masks: [h, w, num]
140
+ """
141
+ # Rescale coordinates (xyxy) from im1_shape to im0_shape
142
+ if ratio_pad is None: # calculate from im0_shape
143
+ gain = min(
144
+ im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]
145
+ ) # gain = old / new
146
+ pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (
147
+ im1_shape[0] - im0_shape[0] * gain
148
+ ) / 2 # wh padding
149
+ else:
150
+ pad = ratio_pad[1]
151
+ top, left = int(pad[1]), int(pad[0]) # y, x
152
+ bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])
153
+
154
+ if len(masks.shape) < 2:
155
+ raise ValueError(
156
+ f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}'
157
+ )
158
+ masks = masks[top:bottom, left:right]
159
+ # masks = masks.permute(2, 0, 1).contiguous()
160
+ # masks = F.interpolate(masks[None], im0_shape[:2], mode='bilinear', align_corners=False)[0]
161
+ # masks = masks.permute(1, 2, 0).contiguous()
162
+ masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))
163
+
164
+ if len(masks.shape) == 2:
165
+ masks = masks[:, :, None]
166
+ return masks
167
+
168
+
169
+ class Annotator:
170
+ # YOLOv5 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
171
+ def __init__(
172
+ self,
173
+ im,
174
+ line_width=None,
175
+ font_size=None,
176
+ font="Arial.ttf",
177
+ pil=False,
178
+ example="abc",
179
+ ):
180
+ assert (
181
+ im.data.contiguous
182
+ ), "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images."
183
+ non_ascii = not is_ascii(
184
+ example
185
+ ) # non-latin labels, i.e. asian, arabic, cyrillic
186
+ self.pil = pil or non_ascii
187
+ if self.pil: # use PIL
188
+ self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
189
+ self.draw = ImageDraw.Draw(self.im)
190
+ self.font = check_pil_font(
191
+ font="Arial.Unicode.ttf" if non_ascii else font,
192
+ size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12),
193
+ )
194
+ else: # use cv2
195
+ self.im = im
196
+ self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
197
+
198
+ def box_label(
199
+ self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255)
200
+ ):
201
+ # Add one xyxy box to image with label
202
+ if self.pil or not is_ascii(label):
203
+ self.draw.rectangle(box, width=self.lw, outline=color) # box
204
+ if label:
205
+ w, h = self.font.getsize(label) # text width, height
206
+ outside = box[1] - h >= 0 # label fits outside box
207
+ self.draw.rectangle(
208
+ (
209
+ box[0],
210
+ box[1] - h if outside else box[1],
211
+ box[0] + w + 1,
212
+ box[1] + 1 if outside else box[1] + h + 1,
213
+ ),
214
+ fill=color,
215
+ )
216
+ # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
217
+ self.draw.text(
218
+ (box[0], box[1] - h if outside else box[1]),
219
+ label,
220
+ fill=txt_color,
221
+ font=self.font,
222
+ )
223
+ else: # cv2
224
+ p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
225
+ cv2.rectangle(
226
+ self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA
227
+ )
228
+ if label:
229
+ tf = max(self.lw - 1, 1) # font thickness
230
+ w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[
231
+ 0
232
+ ] # text width, height
233
+ outside = p1[1] - h >= 3
234
+ p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
235
+ cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
236
+ cv2.putText(
237
+ self.im,
238
+ label,
239
+ (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
240
+ 0,
241
+ self.lw / 3,
242
+ txt_color,
243
+ thickness=tf,
244
+ lineType=cv2.LINE_AA,
245
+ )
246
+
247
+ def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
248
+ """Plot masks at once.
249
+ Args:
250
+ masks (tensor): predicted masks on cuda, shape: [n, h, w]
251
+ colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n]
252
+ im_gpu (tensor): img is in cuda, shape: [3, h, w], range: [0, 1]
253
+ alpha (float): mask transparency: 0.0 fully transparent, 1.0 opaque
254
+ """
255
+ im_gpu = torch.from_numpy(im_gpu) # not sure why we need this fix?
256
+ # print(im_gpu)
257
+ if self.pil:
258
+ # convert to numpy first
259
+ self.im = np.asarray(self.im).copy()
260
+ if len(masks) == 0:
261
+ self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
262
+ colors = torch.tensor(colors, device=im_gpu.device, dtype=torch.float32) / 255.0
263
+ colors = colors[:, None, None] # shape(n,1,1,3)
264
+ masks = masks.unsqueeze(3) # shape(n,h,w,1)
265
+ masks_color = masks * (colors * alpha) # shape(n,h,w,3)
266
+
267
+ inv_alph_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
268
+ mcs = (masks_color * inv_alph_masks).sum(
269
+ 0
270
+ ) * 2 # mask color summand shape(n,h,w,3)
271
+
272
+ im_gpu = im_gpu.flip(dims=[0]) # flip channel
273
+ im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
274
+ im_gpu = im_gpu * inv_alph_masks[-1] + mcs
275
+ im_mask = (im_gpu * 255).byte().cpu().numpy()
276
+ self.im[:] = (
277
+ im_mask
278
+ if retina_masks
279
+ else scale_image(im_gpu.shape, im_mask, self.im.shape)
280
+ )
281
+ if self.pil:
282
+ # convert im back to PIL and update draw
283
+ self.fromarray(self.im)
284
+
285
+ def rectangle(self, xy, fill=None, outline=None, width=1):
286
+ # Add rectangle to image (PIL-only)
287
+ self.draw.rectangle(xy, fill, outline, width)
288
+
289
+ def text(self, xy, text, txt_color=(255, 255, 255), anchor="top"):
290
+ # Add text to image (PIL-only)
291
+ if anchor == "bottom": # start y from font bottom
292
+ w, h = self.font.getsize(text) # text width, height
293
+ xy[1] += 1 - h
294
+ self.draw.text(xy, text, fill=txt_color, font=self.font)
295
+
296
+ def fromarray(self, im):
297
+ # Update self.im from a numpy array
298
+ self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
299
+ self.draw = ImageDraw.Draw(self.im)
300
+
301
+ def result(self):
302
+ # Return annotated image as array
303
+ return np.asarray(self.im)
predict.py ADDED
@@ -0,0 +1,470 @@
1
+ # Must import torch before onnxruntime, else could not create cuda context
2
+ # ref: https://github.com/microsoft/onnxruntime/issues/11092#issuecomment-1386840174
3
+ import torch, torchvision
4
+ import onnxruntime
5
+
6
+ from time import perf_counter
7
+ from openvino.runtime import Core, Layout, get_batch, AsyncInferQueue
8
+ from pathlib import Path
9
+ import yaml
10
+ import cv2
11
+ import numpy as np
12
+ import time
13
+ from plots import Annotator, process_mask, scale_boxes, scale_image, colors
14
+ from loguru import logger
15
+
16
+
17
+ def from_numpy(x):
18
+ return torch.from_numpy(x) if isinstance(x, np.ndarray) else x
19
+
20
+
21
+ def yaml_load(file="data.yaml"):
22
+ # Single-line safe yaml loading
23
+ with open(file, errors="ignore") as f:
24
+ return yaml.safe_load(f)
25
+
26
+
27
+ def load_metadata(f=Path("path/to/meta.yaml")):
28
+ # Load metadata from meta.yaml if it exists
29
+ if f.exists():
30
+ d = yaml_load(f)
31
+ return d["stride"], d["names"] # assign stride, names
32
+ return None, None
33
+
34
+
35
+ def letterbox(
36
+ im,
37
+ new_shape=(640, 640),
38
+ color=(114, 114, 114),
39
+ auto=True,
40
+ scale_fill=False,
41
+ scaleup=True,
42
+ stride=32,
43
+ ):
44
+ # Resize and pad image while meeting stride-multiple constraints
45
+ shape = im.shape[:2] # current shape [height, width]
46
+ if isinstance(new_shape, int):
47
+ new_shape = (new_shape, new_shape)
48
+
49
+ # Scale ratio (new / old)
50
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
51
+ if not scaleup: # only scale down, do not scale up (for better val mAP)
52
+ r = min(r, 1.0)
53
+
54
+ # Compute padding
55
+ ratio = r, r # width, height ratios
56
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
57
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
58
+ if auto: # minimum rectangle
59
+ dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
60
+ elif scale_fill: # stretch
61
+ dw, dh = 0.0, 0.0
62
+ new_unpad = (new_shape[1], new_shape[0])
63
+ ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
64
+
65
+ dw /= 2 # divide padding into 2 sides
66
+ dh /= 2
67
+
68
+ if shape[::-1] != new_unpad: # resize
69
+ im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
70
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
71
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
72
+ im = cv2.copyMakeBorder(
73
+ im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
74
+ ) # add border
75
+ return im, ratio, (dw, dh)
76
+
77
+
78
+ def xywh2xyxy(x):
79
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
80
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
81
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
82
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
83
+ y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
84
+ y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
85
+ return y
86
+
87
+
88
+ def box_iou(box1, box2, eps=1e-7):
89
+ # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
90
+ """
91
+ Return intersection-over-union (Jaccard index) of boxes.
92
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
93
+ Arguments:
94
+ box1 (Tensor[N, 4])
95
+ box2 (Tensor[M, 4])
96
+ Returns:
97
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
98
+ IoU values for every element in boxes1 and boxes2
99
+ """
100
+
101
+ # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
102
+ (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
103
+ inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
104
+
105
+ # IoU = inter / (area1 + area2 - inter)
106
+ return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
107
+
108
+
109
+ def non_max_suppression(
110
+ prediction,
111
+ conf_thres=0.25,
112
+ iou_thres=0.45,
113
+ classes=None,
114
+ agnostic=False,
115
+ multi_label=False,
116
+ labels=(),
117
+ max_det=300,
118
+ nm=0, # number of masks
119
+ redundant=True, # require redundant detections
120
+ ):
121
+ """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
122
+ Returns:
123
+ list of detections, on (n,6) tensor per image [xyxy, conf, cls]
124
+ """
125
+
126
+ if isinstance(
127
+ prediction, (list, tuple)
128
+ ): # YOLOv5 model in validation model, output = (inference_out, loss_out)
129
+ prediction = prediction[0] # select only inference output
130
+
131
+ device = prediction.device
132
+ mps = "mps" in device.type # Apple MPS
133
+ if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
134
+ prediction = prediction.cpu()
135
+ bs = prediction.shape[0] # batch size
136
+ nc = prediction.shape[2] - nm - 5 # number of classes
137
+ xc = prediction[..., 4] > conf_thres # candidates
138
+
139
+ # Checks
140
+ assert (
141
+ 0 <= conf_thres <= 1
142
+ ), f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
143
+ assert (
144
+ 0 <= iou_thres <= 1
145
+ ), f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
146
+
147
+ # Settings
148
+ # min_wh = 2 # (pixels) minimum box width and height
149
+ max_wh = 7680 # (pixels) maximum box width and height
150
+ max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
151
+ multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
152
+ merge = False # use merge-NMS
153
+
154
+ t = time.time()
155
+ mi = 5 + nc # mask start index
156
+ output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
157
+ for xi, x in enumerate(prediction): # image index, image inference
158
+ # Apply constraints
159
+ # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
160
+ x = x[xc[xi]] # confidence
161
+
162
+ # Cat apriori labels if autolabelling
163
+ if labels and len(labels[xi]):
164
+ lb = labels[xi]
165
+ v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
166
+ v[:, :4] = lb[:, 1:5] # box
167
+ v[:, 4] = 1.0 # conf
168
+ v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls
169
+ x = torch.cat((x, v), 0)
170
+
171
+ # If none remain process next image
172
+ if not x.shape[0]:
173
+ continue
174
+
175
+ # Compute conf
176
+ x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
177
+
178
+ # Box/Mask
179
+ box = xywh2xyxy(
180
+ x[:, :4]
181
+ ) # center_x, center_y, width, height) to (x1, y1, x2, y2)
182
+ mask = x[:, mi:] # zero columns if no masks
183
+
184
+ # Detections matrix nx6 (xyxy, conf, cls)
185
+ if multi_label:
186
+ i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
187
+ x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
188
+ else: # best class only
189
+ conf, j = x[:, 5:mi].max(1, keepdim=True)
190
+ x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
191
+
192
+ # Filter by class
193
+ if classes is not None:
194
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
195
+
196
+ # Apply finite constraint
197
+ # if not torch.isfinite(x).all():
198
+ # x = x[torch.isfinite(x).all(1)]
199
+
200
+ # Check shape
201
+ n = x.shape[0] # number of boxes
202
+ if not n: # no boxes
203
+ continue
204
+ elif n > max_nms: # excess boxes
205
+ x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
206
+ else:
207
+ x = x[x[:, 4].argsort(descending=True)] # sort by confidence
208
+
209
+ # Batched NMS
210
+ c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
211
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
212
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
213
+ if i.shape[0] > max_det: # limit detections
214
+ i = i[:max_det]
215
+ if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
216
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
217
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
218
+ weights = iou * scores[None] # box weights
219
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(
220
+ 1, keepdim=True
221
+ ) # merged boxes
222
+ if redundant:
223
+ i = i[iou.sum(1) > 1] # require redundancy
224
+
225
+ output[xi] = x[i]
226
+ if mps:
227
+ output[xi] = output[xi].to(device)
228
+
229
+ return output
230
+
231
+
232
+ class Model:
233
+ def __init__(
234
+ self,
235
+ model_path,
236
+ imgsz=320,
237
+ classes=None,
238
+ device="CPU",
239
+ plot_mask=False,
240
+ conf_thres=0.7,
241
+ n_jobs=1,
242
+ is_async=False,
243
+ ):
244
+ # filter by class: classes=[0], or classes=[0, 2, 3]
245
+ model_type = "onnx" if Path(model_path).suffix == ".onnx" else "openvino"
246
+ assert Path(model_path).exists(), f"Model {model_path} not found"
247
+ assert Path(model_path).suffix in (
248
+ ".onnx",
249
+ ".xml",
250
+ ), "Model must be .onnx or .xml"
251
+ self.model_type = model_type
252
+ self.model_path = model_path
253
+ self.imgsz = imgsz
254
+ self.classes = classes
255
+ self.plot_mask = plot_mask
256
+ self.conf_thres = conf_thres
257
+
258
+ # async settings
259
+ self.n_jobs = n_jobs
260
+ self.is_async = is_async
261
+ self.completed_results = {} # key: frame_id, value: inference results
262
+ self.ori_cv_imgs = {} # key: frame_id, value: original cv image
263
+ self.prep_cv_imgs = {} # key: frame_id, value: preprocessed cv image
264
+
265
+ if self.model_type == "onnx":
266
+ assert is_async is False, "Async mode is not supported for ONNX models"
267
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
268
+ session = onnxruntime.InferenceSession(model_path, providers=providers)
269
+ self.session = session
270
+ output_names = [x.name for x in session.get_outputs()]
271
+ self.output_names = output_names
272
+ meta = session.get_modelmeta().custom_metadata_map # metadata
273
+ if "stride" in meta:
274
+ stride, names = int(meta["stride"]), eval(meta["names"])
275
+ self.stride = stride
276
+ self.names = names
277
+ elif self.model_type == "openvino":
278
+ # load OpenVINO model
279
+ assert Path(model_path).suffix == ".xml", "OpenVINO model must be .xml"
280
+ ie = Core()
281
+ weights = Path(model_path).with_suffix(".bin").as_posix()
282
+ network = ie.read_model(model=model_path, weights=weights)
283
+ if network.get_parameters()[0].get_layout().empty:
284
+ network.get_parameters()[0].set_layout(Layout("NCHW"))
285
+ batch_dim = get_batch(network)
286
+ if batch_dim.is_static:
287
+ batch_size = batch_dim.get_length()
288
+
289
+ # To run inference on M1, we must export the IR model using "mo --use_legacy_frontend"
290
+ # Otherwise, we would get the following error when compiling the model
291
+ # https://github.com/openvinotoolkit/openvino/issues/12476#issuecomment-1222202804
292
+ config = {}
293
+ if n_jobs == "auto":
294
+ config = {"PERFORMANCE_HINT": "THROUGHPUT"}
295
+ self.executable_network = ie.compile_model(
296
+ network, device_name=device, config=config
297
+ )
298
+ num_requests = self.executable_network.get_property(
299
+ "OPTIMAL_NUMBER_OF_INFER_REQUESTS"
300
+ )
301
+ self.n_jobs = num_requests if n_jobs == "auto" else int(n_jobs)
302
+ logger.info(f"Optimal number of infer requests should be: {num_requests}")
303
+ self.stride, self.names = load_metadata(
304
+ Path(weights).with_suffix(".yaml")
305
+ ) # load metadata
306
+
307
+ if is_async:
308
+ logger.info(f"Using num of infer requests jobs: {n_jobs}")
309
+ self.pipeline = AsyncInferQueue(self.executable_network, self.n_jobs)
310
+ self.pipeline.set_callback(self.callback)
311
+
312
+ def preprocess(self, cv_img, pt=False):
313
+ im = letterbox(cv_img, self.imgsz, stride=self.stride, auto=pt)[
314
+ 0
315
+ ] # padded resize
316
+ im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
317
+ im = np.ascontiguousarray(im) # contiguous
318
+ im = torch.from_numpy(im)
319
+ im = im.float() # uint8 to fp16/32
320
+ im /= 255 # 0 - 255 to 0.0 - 1.0
321
+ if len(im.shape) == 3:
322
+ im = im[None] # expand for batch dim
323
+ im = im.cpu().numpy() # torch to numpy
324
+ return im
325
+
326
+ def postprocess(self, y, ori_cv_im, prep_im):
327
+ y = [from_numpy(x) for x in y]
328
+ pred, proto = y[0], y[-1]
329
+
330
+ im0 = ori_cv_im
331
+
332
+ # NMS
333
+ iou_thres = 0.45
334
+ agnostic_nms = False
335
+ max_det = 1 # maximum detections per image, only 1 aorta is needed
336
+ pred = non_max_suppression(
337
+ pred,
338
+ self.conf_thres,
339
+ iou_thres,
340
+ self.classes,
341
+ agnostic_nms,
342
+ max_det=max_det,
343
+ nm=32,
344
+ )
345
+
346
+ # Process predictions
347
+ line_thickness = 3
348
+ annotator = Annotator(
349
+ np.ascontiguousarray(im0),
350
+ line_width=line_thickness,
351
+ example=str(self.names),
352
+ )
353
+ i = 0
354
+ det = pred[0]
355
+ im = prep_im
356
+ r_xyxy, r_conf, r_masks = None, None, None
357
+ if len(pred[0]):
358
+ masks = process_mask(
359
+ proto[i],
360
+ det[:, 6:],
361
+ det[:, :4],
362
+ (self.imgsz, self.imgsz),
363
+ upsample=True,
364
+ ) # HWC
365
+ det[:, :4] = scale_boxes(
366
+ (self.imgsz, self.imgsz), det[:, :4], im0.shape
367
+ ).round() # rescale boxes to im0 size
368
+
369
+ # Mask plotting
370
+ if self.plot_mask:
371
+ annotator.masks(
372
+ masks,
373
+ colors=[colors(x, True) for x in det[:, 5]],
374
+ im_gpu=im[i],
375
+ alpha=0.1,
376
+ )
377
+
378
+ # Write results
379
+ for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])):
380
+ # Add bbox to image
381
+ c = int(cls) # integer class
382
+ label = f"{self.names[c]} {conf:.2f}"
383
+ annotator.box_label(xyxy, label, color=colors(c, True))
384
+ r_xyxy = xyxy
385
+ r_conf = conf
386
+ r_xyxy = [i.int().numpy().item() for i in r_xyxy]
387
+ r_conf = r_conf.numpy().item()
388
+ r_masks = scale_image((self.imgsz, self.imgsz), masks.numpy()[0], im0.shape)
389
+ return annotator.result(), (r_xyxy, r_conf, r_masks)
390
+
391
+ def predict(self, cv_img):
392
+ # return the annotated image and the bounding box
393
+ result_cv_img, xyxy = None, None
394
+ im = self.preprocess(cv_img)
395
+ if self.model_type == "onnx":
396
+ y = self.session.run(
397
+ self.output_names, {self.session.get_inputs()[0].name: im}
398
+ )
399
+ elif self.model_type == "openvino":
400
+ # OpenVINO model inference
401
+ # Note: Please use FP32 model on M1, otherwise you will get many runtime errors
402
+ # Very slow on M1, but works
403
+ # start = perf_counter()
404
+ y = list(self.executable_network([im]).values())
405
+ # logger.info(f"OpenVINO inference time: {perf_counter() - start:.3f}s")
406
+ result_cv_img, others = self.postprocess(y, cv_img, im)
407
+ return result_cv_img, others
408
+
409
+ def callback(self, request, userdata):
410
+ # callback function for AsyncInferQueue
411
+ outputs = request.outputs
412
+ frame_id = userdata
413
+ self.completed_results[frame_id] = [i.data for i in outputs]
414
+
415
+ def predict_async(self, cv_img, frame_id):
416
+ assert self.is_async, "Please set is_async=True when initializing the model"
417
+ self.ori_cv_imgs[frame_id] = cv_img
418
+ im = self.preprocess(cv_img)
419
+ self.prep_cv_imgs[frame_id] = im
420
+
421
+ # Note: The start_async function call is not required to be synchronized - it waits for any available job if the queue is busy/overloaded.
422
+ # https://docs.openvino.ai/latest/openvino_docs_OV_UG_Python_API_exclusives.html#asyncinferqueue
423
+ #
424
+ # idle_id = self.pipeline.get_idle_request_id()
425
+ # self.pipeline.start_async({idle_id: im}, frame_id)
426
+ self.pipeline.start_async({0: im}, frame_id)
427
+
428
+ def is_free_to_infer_async(self):
429
+ """Returns True if any free request in the pool, otherwise False"""
430
+ assert self.is_async, "Please set is_async=True when initializing the model"
431
+ return self.pipeline.is_ready()
432
+
433
+ def get_result(self, frame_id):
434
+ """Returns the inference result for the given frame_id"""
435
+ assert self.is_async, "Please set is_async=True when initializing the model"
436
+ if frame_id in self.completed_results:
437
+ y = self.completed_results.pop(frame_id)
438
+ cv_img = self.ori_cv_imgs.pop(frame_id)
439
+ im = self.prep_cv_imgs.pop(frame_id)
440
+ result_cv_img, others = self.postprocess(y, cv_img, im)
441
+ return result_cv_img, others
442
+ return None
443
+
444
+
445
+ if __name__ == "__main__":
446
+ m_p = "weights/yolov7seg-JH-v1.onnx"
447
+ m_p = "weights/yolov5s-seg-MK-v1.onnx"
448
+ m_p = "weights/best_openvino_model/best.xml"
449
+ imgsz = 320
450
+ # imgsz = 640
451
+ model = Model(model_path=m_p, imgsz=imgsz)
452
+
453
+ # inference an image using the loaded model
454
+ # source = 'Tim_3-0-00-20.05.jpg'
455
+ path = "data/Jimmy_2-0-00-04.63.jpg"
456
+ assert Path(path).exists(), f"Input image {path} doesn't exist"
457
+
458
+ # output path
459
+ save_dir = "runs/predict"
460
+ Path(save_dir).mkdir(parents=True, exist_ok=True)
461
+ out_p = f"{save_dir}/{Path(path).stem}.jpg"
462
+
463
+ # load image and preprocess
464
+ im0 = cv2.imread(path) # BGR
465
+ result_cv_img, _ = model.predict(im0)
466
+ if result_cv_img is not None:
467
+ cv2.imwrite(out_p, result_cv_img)
468
+ logger.info(f"Saved result to {out_p}")
469
+ else:
470
+ logger.error("No result, something went wrong")
requirements.txt ADDED
@@ -0,0 +1,26 @@
1
+ # dataset processing
2
+ loguru==0.6.0
3
+ typer[all]==0.7.0
4
+ fiftyone==0.19.1
5
+ pycocotools==2.0.6
6
+
7
+ torch==1.13.0
8
+ torchvision==0.14.0
9
+ openvino==2022.2.0; sys_platform != "darwin"
10
+ openvino-arm==2022.1.0.1; sys_platform == "darwin"
11
+ opencv-python==4.6.0.66
12
+ PyYAML==6.0
13
+ onnx==1.13.1
14
+ onnxruntime==1.13.1
15
+ onnxruntime-gpu==1.13.1; sys_platform != "darwin"
16
+
17
+ # demo GUI
18
+ PySide6==6.4.1
19
+ scipy==1.9.3
20
+ matplotlib==3.5.2
21
+
22
+ # demo plot
23
+ plotly==5.11.0
24
+ pandas==1.5.2
25
+ kaleido==0.2.1; platform_system != "Windows"
26
+ kaleido==0.1.0post1; platform_system == "Windows"
roi.py ADDED
@@ -0,0 +1,34 @@
1
+ import cv2
2
+ import json
3
+ import argparse
4
+ from pathlib import Path
5
+
6
+
7
+ # get the image from arguments
8
+ ap = argparse.ArgumentParser()
9
+ ap.add_argument("-i", "--image", required=True, help="Path to the image")
10
+ args = vars(ap.parse_args())
11
+
12
+ # check if the image does exist
13
+ img_p = Path(args["image"])
14
+ assert img_p.exists(), "Image does not exist"
15
+
16
+ # Read the image
17
+ img = cv2.imread(args["image"])
18
+
19
+ # Select the ROI from the image
20
+ ROI = cv2.selectROI("Image", img, False, False)
21
+
22
+ # Append the ROI coordinates to the csv file
23
+ # header: filename, ori_width, ori_height, roi_x, roi_y, roi_width, roi_height
24
+ with open("roi.csv", "a") as f:
25
+ # if no file exists, create a new one with the header
26
+ if f.tell() == 0:
27
+ f.write("filename,ori_width,ori_height,roi_x,roi_y,roi_width,roi_height\n")
28
+ ori_w, ori_h = img.shape[1], img.shape[0]
29
+ f.write(f"{img_p.name},{ori_w},{ori_h},{ROI[0]},{ROI[1]},{ROI[2]},{ROI[3]}\n")
30
+
31
+ # Display cropped image
32
+ cropped = img[ROI[1] : ROI[1] + ROI[3], ROI[0] : ROI[0] + ROI[2]]
33
+ cv2.imshow("Cropped Image", cropped)
34
+ cv2.waitKey(0)
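+
+ # Illustrative usage (the image path is an example): python roi.py -i data/sample_frame.jpg
+ # Each run appends one ROI row to roi.csv in the current directory.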
try_chart.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
usgfw2wrapper.dll ADDED
Binary file (15.4 kB). View file
 
weights/.keep ADDED
@@ -0,0 +1 @@
1
+
weights/yolov5s-v2 ADDED
@@ -0,0 +1 @@
1
+