update_repo
- README.md +143 -3
- add_clearml_yolov5.patch +215 -0
- dataset.py +526 -0
- demo.bat +12 -0
- demo.py +781 -0
- demo_headless.sh +27 -0
- eval.py +397 -0
- plots.py +303 -0
- predict.py +470 -0
- requirements.txt +26 -0
- roi.py +34 -0
- try_chart.ipynb +0 -0
- usgfw2wrapper.dll +0 -0
- weights/.keep +1 -0
- weights/yolov5s-v2 +1 -0
README.md
CHANGED
@@ -1,3 +1,143 @@
<!-- #region -->
# Automation of Aorta Measurement in Ultrasound Images

## Env setup

Suggested hardware:

- GPU: NVIDIA RTX 3090 or higher x1 (model training using PyTorch)
- CPU: 11th Gen Intel(R) Core(TM) i9-11900KF @ 3.50GHz, or higher (model inference using OpenVINO)

Software stack:

- OS: Ubuntu 20.04 LTS
- Python: 3.8+
- Python Env: conda

```shell
conda create -n aorta python=3.8 -y
conda activate aorta
pip install -r requirements.txt
```

## Dataset

Steps to prepare the dataset:

1. Collect images and import them into CVAT
2. Label the images in CVAT
3. Export the labelled data in `COCO 1.0` format using CVAT
   1. Go to the CVAT `Projects` page
   2. Click `⋮` on the `aorta` project
   3. Click `Export dataset`
      - Format: `COCO 1.0`
      - Save images: `Yes`
4. Convert the new split data into YOLOv5 format (see the label-format sketch after the command below)

```shell
python dataset.py coco2yolov5 [path/to/coco/input/dir] [path/to/yolov5/output/dir]
```
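
For reference, this is roughly what `dataset.py coco2yolov5` writes for each object in a segmentation label file: the COCO polygon points are normalized by image width/height and written as one `class x1 y1 x2 y2 ...` line. A minimal standalone sketch (the numbers are toy values, not from the real dataset):

```python
# Minimal sketch of the YOLOv5-seg label line produced during conversion:
# "<class_id> x1 y1 x2 y2 ..." with all coordinates normalized to [0, 1].
def polygon_to_yolov5_seg_line(class_id, polygon_xy, img_w, img_h):
    # polygon_xy is a flat COCO-style list [x1, y1, x2, y2, ...] in pixels
    normalized = [
        p / img_w if i % 2 == 0 else p / img_h
        for i, p in enumerate(polygon_xy)
    ]
    return f"{class_id} " + " ".join(f"{p:.6f}" for p in normalized)


# toy example: a triangle on a 640x480 image, class 0 (aorta)
print(polygon_to_yolov5_seg_line(0, [100, 50, 200, 60, 150, 120], 640, 480))
```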

[CVAT](https://github.com/cvat-ai/cvat/tree/v2.3.0) info, set up with Docker Compose:

- Server version: 2.3
- Core version: 7.3.0
- Canvas version: 2.16.1
- UI version: 1.45.0

Dataset related scripts:

- [coco2yolov5seg.ipynb](../coco2yolov5seg.ipynb): Convert COCO format to YOLOv5 format for the segmentation task
- [coco_merge_split.ipynb](../coco_merge_split.ipynb): Merge and split a COCO format dataset

## Training / Validation / Export

Model choice: prefer [yolov5-seg] over [yolov7-seg] for training/validating/exporting models. Performance comparison:

- yolov5s-seg: fast transfer learning (~5-10 mins for 100 epochs on an RTX 3090) and fast CPU inference
- yolov7-seg: seems too heavy (slower inference on CPU)

Please refer to the yolov5-seg and yolov7-seg repos for details on training/validating/exporting models.

[yolov5-seg]: https://github.com/ultralytics/yolov5/blob/master/segment/tutorial.ipynb
[yolov7-seg]: https://github.com/WongKinYiu/yolov7/tree/u7/seg

### yolov5-seg

Tested commit:

```shell
# Assume the working dir is aorta/
git clone https://github.com/ultralytics/yolov5
cd yolov5
git checkout 23c492321290266810e08fa5ee9a23fc9d6a571f
git apply ../add_clearml_yolov5.patch
```

As of 2023, yolov5 seg doesn't support ClearML, but there is a [PR](https://github.com/ultralytics/yolov5/pull/10752) for it. So we can either manually update those files to track the training process with ClearML, or apply [add_clearml_yolov5.patch](./add_clearml_yolov5.patch).

```shell
# Example
## Original training script
python segment/train.py --img 640 --batch 16 --epochs 3 --data coco128-seg.yaml --weights yolov5s-seg.pt --cache

## Updated training script with ClearML support
python segment/train.py --project [clearml_project_name] --name [task_name] --img 640 --batch 16 --epochs 3 --data coco128-seg.yaml --weights yolov5s-seg.pt --cache
```
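
Under the hood, the patched `ClearmlLogger` simply creates a ClearML task and reports scalars per epoch. A minimal sketch of the same calls, assuming `clearml` is installed and configured via `clearml-init`; the project and task names below are placeholders:

```python
from clearml import Task

# Create a ClearML task the same way the patched ClearmlLogger does
task = Task.init(
    project_name="aorta-seg",      # placeholder project name
    task_name="yolov5s-seg-exp1",  # placeholder task name
    tags=["YOLOv5"],
    output_uri=True,
)

# Metrics are reported as title/series pairs, e.g. "metrics/mAP_0.5"
task.get_logger().report_scalar("metrics", "mAP_0.5", value=0.8, iteration=1)
```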

## Test video

- Test video: [Demo.mp4](./Demo.mp4)
- The mp4 was converted from the original avi using `ffmpeg`:

```shell
ffmpeg -i "Demo.avi" -vcodec h264 -acodec aac -b:v 500k -strict -2 Demo.mp4
```
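
To sanity-check the converted video, the same OpenCV properties that `demo.py` logs at startup can be read directly; a small sketch:

```python
import cv2

# Open the converted test video and print its basic properties
cap = cv2.VideoCapture("Demo.mp4")
assert cap.isOpened(), "Cannot open Demo.mp4"
fps = cap.get(cv2.CAP_PROP_FPS)
n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
print(f"{w}x{h} @ {fps:.2f} fps, {n_frames} frames")
cap.release()
```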

## Demo (POC for 2022 Intel DevCup)

```shell
# run the demo using the OpenVINO model
python demo.py --video Demo.mp4 --model weights/yolov5s-v2/best_openvino_model/yolov5-640-v2.xml --plot-mask --img-size 640

# or run the demo using the ONNX model
python demo.py --video Demo.mp4 --model weights/yolov5s-v2/yolov5-640.onnx --plot-mask --img-size 640

# or run in headless mode, generating a recording of the demo
./demo_headless.sh --video Demo.mp4 --model [path/to/model]
```
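
For reference, `demo.py` drives inference through the `Model` class in `predict.py`. A minimal synchronous sketch using the same constructor arguments `demo.py` passes; the concrete values below (confidence threshold, class filter, input image path) are placeholders, not project defaults:

```python
import cv2

from predict import Model

# Same keyword arguments demo.py passes to Model; values here are illustrative.
model = Model(
    model_path="weights/yolov5s-v2/best_openvino_model/yolov5-640-v2.xml",
    imgsz=640,
    classes=None,      # assumption: no class filtering
    device="CPU",
    plot_mask=True,
    conf_thres=0.25,   # assumption: typical YOLOv5 default
    is_async=False,    # synchronous mode, as demo.py uses with --jobs 1
    n_jobs=1,
)

frame = cv2.imread("some_ultrasound_frame.png")  # placeholder input image
results = model.predict(frame)
if results is not None:
    color_frame, (xyxy, conf, _) = results  # annotated frame, bbox, confidence
    cv2.imwrite("annotated.png", color_frame)
```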

## Deploy PyInstaller EXE

Only tested on Windows 10:

```shell
pip install pyinstaller==5.9
pyinstaller demo.py
# (TODO) Replace the following manual steps with pyinstaller --add-data or a spec file
#
# Manually copy files to dist\demo
# 1. Copy the best_openvino_model folder to dist\demo\
# 2. Copy the OpenVINO runtime files to dist\demo, from
#    C:\Users\sa\miniforge3\envs\echo\Lib\site-packages\openvino\libs
#      plugins.xml
#      openvino_ir_frontend.dll
#      openvino_intel_cpu_plugin.dll
#      openvino_intel_gpu_plugin.dll
```
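
Until the TODO above is addressed, the manual copy steps can also be scripted. A minimal sketch, assuming the model folder from the demo command above and the OpenVINO libs folder of the active conda env; the paths are illustrative and should be adjusted:

```python
import shutil
import sys
from pathlib import Path

# Illustrative paths; adjust to your conda env and build output.
site_packages = Path(sys.prefix) / "Lib" / "site-packages"
openvino_libs = site_packages / "openvino" / "libs"
dist_dir = Path("dist") / "demo"

# 1. Copy the OpenVINO model folder next to the frozen EXE
shutil.copytree(
    "weights/yolov5s-v2/best_openvino_model",
    dist_dir / "best_openvino_model",
    dirs_exist_ok=True,
)

# 2. Copy the OpenVINO runtime files the frozen app needs
for name in [
    "plugins.xml",
    "openvino_ir_frontend.dll",
    "openvino_intel_cpu_plugin.dll",
    "openvino_intel_gpu_plugin.dll",
]:
    shutil.copy2(openvino_libs / name, dist_dir / name)
```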

Troubleshooting: if the deployed EXE fails with `ValueError: --plotlyjs argument is not a valid URL or file path:`, move the dist folder to a location whose path contains no special or Chinese characters. Reference: <https://github.com/plotly/Kaleido/issues/57>

## Paper

https://www.nature.com/articles/s41746-024-01269-4

Chiu, IM., Chen, TY., Zheng, YC. et al. Prospective clinical evaluation of deep learning for ultrasonographic screening of abdominal aortic aneurysms. npj Digit. Med. 7, 282 (2024).
<!-- #endregion -->

```python

```
add_clearml_yolov5.patch
ADDED
@@ -0,0 +1,215 @@
```diff
diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py
index 9de1f22..93b9ba2 100644
--- a/utils/loggers/__init__.py
+++ b/utils/loggers/__init__.py
@@ -110,7 +110,7 @@ class Loggers():
         if clearml and 'clearml' in self.include:
             try:
                 self.clearml = ClearmlLogger(self.opt, self.hyp)
-            except Exception:
+            except Exception as e:
                 self.clearml = None
                 prefix = colorstr('ClearML: ')
                 LOGGER.warning(f'{prefix}WARNING ⚠️ ClearML is installed but not configured, skipping ClearML logging.'
@@ -159,10 +159,11 @@ class Loggers():
         paths = self.save_dir.glob('*labels*.jpg')  # training labels
         if self.wandb:
             self.wandb.log({'Labels': [wandb.Image(str(x), caption=x.name) for x in paths]})
-        # if self.clearml:
-        #    pass  # ClearML saves these images automatically using hooks
         if self.comet_logger:
             self.comet_logger.on_pretrain_routine_end(paths)
+        if self.clearml:
+            for path in paths:
+                self.clearml.log_plot(title=path.stem, plot_path=path)
 
     def on_train_batch_end(self, model, ni, imgs, targets, paths, vals):
         log_dict = dict(zip(self.keys[:3], vals))
@@ -289,6 +290,8 @@ class Loggers():
             self.wandb.finish_run()
 
         if self.clearml and not self.opt.evolve:
+            self.clearml.log_summary(dict(zip(self.keys[3:10], results)))
+            [self.clearml.log_plot(title=f.stem, plot_path=f) for f in files]
             self.clearml.task.update_output_model(model_path=str(best if best.exists() else last),
                                                   name='Best Model',
                                                   auto_delete_file=False)
@@ -303,6 +306,8 @@ class Loggers():
             self.wandb.wandb_run.config.update(params, allow_val_change=True)
         if self.comet_logger:
             self.comet_logger.on_params_update(params)
+        if self.clearml:
+            self.clearml.task.connect(params)
 
 
 class GenericLogger:
@@ -315,7 +320,7 @@ class GenericLogger:
         include: loggers to include
     """
 
-    def __init__(self, opt, console_logger, include=('tb', 'wandb')):
+    def __init__(self, opt, console_logger, include=('tb', 'wandb', 'clearml')):
         # init default loggers
         self.save_dir = Path(opt.save_dir)
         self.include = include
@@ -333,6 +338,22 @@ class GenericLogger:
                                     config=opt)
         else:
             self.wandb = None
+
+        if clearml and 'clearml' in self.include:
+            try:
+                # Hyp is not available in classification mode
+                if 'hyp' not in opt:
+                    hyp = {}
+                else:
+                    hyp = opt.hyp
+                self.clearml = ClearmlLogger(opt, hyp)
+            except Exception:
+                self.clearml = None
+                prefix = colorstr('ClearML: ')
+                LOGGER.warning(f'{prefix}WARNING ⚠️ ClearML is installed but not configured, skipping ClearML logging.'
+                               f' See https://github.com/ultralytics/yolov5/tree/master/utils/loggers/clearml#readme')
+        else:
+            self.clearml = None
 
     def log_metrics(self, metrics, epoch):
         # Log metrics dictionary to all loggers
@@ -349,6 +370,9 @@ class GenericLogger:
 
         if self.wandb:
             self.wandb.log(metrics, step=epoch)
+
+        if self.clearml:
+            self.clearml.log_scalars(metrics, epoch)
 
     def log_images(self, files, name='Images', epoch=0):
         # Log images to all loggers
@@ -361,6 +385,12 @@ class GenericLogger:
 
         if self.wandb:
             self.wandb.log({name: [wandb.Image(str(f), caption=f.name) for f in files]}, step=epoch)
+
+        if self.clearml:
+            if name == 'Results':
+                [self.clearml.log_plot(f.stem, f) for f in files]
+            else:
+                self.clearml.log_debug_samples(files, title=name)
 
     def log_graph(self, model, imgsz=(640, 640)):
         # Log model graph to all loggers
@@ -373,11 +403,17 @@ class GenericLogger:
             art = wandb.Artifact(name=f'run_{wandb.run.id}_model', type='model', metadata=metadata)
             art.add_file(str(model_path))
             wandb.log_artifact(art)
+
+        if self.clearml:
+            self.clearml.log_model(model_path=model_path, model_name=model_path.stem)
 
     def update_params(self, params):
         # Update the parameters logged
         if self.wandb:
             wandb.run.config.update(params, allow_val_change=True)
+
+        if self.clearml:
+            self.clearml.task.connect(params)
 
 
 def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
diff --git a/utils/loggers/clearml/clearml_utils.py b/utils/loggers/clearml/clearml_utils.py
index 2764abe..e7525da 100644
--- a/utils/loggers/clearml/clearml_utils.py
+++ b/utils/loggers/clearml/clearml_utils.py
@@ -3,6 +3,9 @@ import glob
 import re
 from pathlib import Path
 
+import matplotlib.image as mpimg
+import matplotlib.pyplot as plt
+
 import numpy as np
 import yaml
 
@@ -79,13 +82,16 @@ class ClearmlLogger:
         # Maximum number of images to log to clearML per epoch
         self.max_imgs_to_log_per_epoch = 16
         # Get the interval of epochs when bounding box images should be logged
-        self.bbox_interval = opt.bbox_interval
+        # Only for detection task though!
+        if 'bbox_interval' in opt:
+            self.bbox_interval = opt.bbox_interval
         self.clearml = clearml
         self.task = None
         self.data_dict = None
         if self.clearml:
             self.task = Task.init(
-                project_name=opt.project if opt.project != 'runs/train' else 'YOLOv5',
+                # project_name=opt.project if opt.project != 'runs/train' else 'YOLOv5',
+                project_name=opt.project if not str(opt.project).startswith('runs/') else 'YOLOv5',
                 task_name=opt.name if opt.name != 'exp' else 'Training',
                 tags=['YOLOv5'],
                 output_uri=True,
@@ -112,6 +118,53 @@ class ClearmlLogger:
             # Set data to data_dict because wandb will crash without this information and opt is the best way
             # to give it to them
             opt.data = self.data_dict
+
+    def log_scalars(self, metrics, epoch):
+        """
+        Log scalars/metrics to ClearML
+        arguments:
+        metrics (dict) Metrics in dict format: {"metrics/mAP": 0.8, ...}
+        epoch (int) iteration number for the current set of metrics
+        """
+        for k, v in metrics.items():
+            title, series = k.split('/')
+            self.task.get_logger().report_scalar(title, series, v, epoch)
+
+    def log_model(self, model_path, model_name, epoch=0):
+        """
+        Log model weights to ClearML
+        arguments:
+        model_path (PosixPath or str) Path to the model weights
+        model_name (str) Name of the model visible in ClearML
+        epoch (int) Iteration / epoch of the model weights
+        """
+        self.task.update_output_model(model_path=str(model_path),
+                                      name=model_name,
+                                      iteration=epoch,
+                                      auto_delete_file=False)
+
+    def log_summary(self, metrics):
+        """
+        Log final metrics to a summary table
+        arguments:
+        metrics (dict) Metrics in dict format: {"metrics/mAP": 0.8, ...}
+        """
+        for k, v in metrics.items():
+            self.task.get_logger().report_single_value(k, v)
+
+    def log_plot(self, title, plot_path):
+        """
+        Log image as plot in the plot section of ClearML
+        arguments:
+        title (str) Title of the plot
+        plot_path (PosixPath or str) Path to the saved image file
+        """
+        img = mpimg.imread(plot_path)
+        fig = plt.figure()
+        ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[])  # no ticks
+        ax.imshow(img)
+
+        self.task.get_logger().report_matplotlib_figure(title, "", figure=fig, report_interactive=False)
 
     def log_debug_samples(self, files, title='Debug Samples'):
         """
@@ -126,7 +179,8 @@ class ClearmlLogger:
             it = re.search(r'_batch(\d+)', f.name)
             iteration = int(it.groups()[0]) if it else 0
             self.task.get_logger().report_image(title=title,
-                                                series=f.name.replace(it.group(), ''),
+                                                # series=f.name.replace(it.group(), ''),
+                                                series=f.name.replace(f"_batch{iteration}", ''),
                                                 local_path=str(f),
                                                 iteration=iteration)
 
```
dataset.py
ADDED
@@ -0,0 +1,526 @@
```python
import typer
import fiftyone as fo
from fiftyone import ViewField as F
from pathlib import Path
from pycocotools.coco import COCO
from loguru import logger
import cv2
import shutil
import os
import random
from collections import defaultdict
import csv


DEFAULT_EXCLUDE_NAME = "Ellen"
DEFAULT_INS_TRAIN = "instances_Train.json"
DEFAULT_INS_TEST = "instances_Test.json"

app = typer.Typer()


@app.command()
def newsplit(
    in_dir: str,
    train_json=DEFAULT_INS_TRAIN,
    test_json=DEFAULT_INS_TEST,
    exclude_name=DEFAULT_EXCLUDE_NAME,
):
    """
    Merge the train and test datasets,
    and then split them into new train/test by leaving one person out.
    """

    # load the dataset
    logger.info("Loading datasets...")
    ds1 = fo.Dataset.from_dir(
        dataset_type=fo.types.COCODetectionDataset,
        data_path=Path(in_dir) / "images",
        labels_path=Path(in_dir) / "annotations" / train_json,
    )
    ds2 = fo.Dataset.from_dir(
        dataset_type=fo.types.COCODetectionDataset,
        data_path=Path(in_dir) / "images",
        labels_path=Path(in_dir) / "annotations" / test_json,
    )

    logger.info(f"[Before] Num samples in train: {len(ds1)}")
    logger.info(f"[Before] Num samples in test: {len(ds2)}")

    # merge the datasets
    ds1.merge_samples(ds2)

    # generate the new split
    logger.info(f"Excluding name in filepath as train set: {exclude_name}")
    new_train_view = ds1.match(~F("filepath").re_match(exclude_name))
    new_test_view = ds1.match(F("filepath").re_match(exclude_name))
    assert len(new_train_view) + len(new_test_view) == len(ds1)
    logger.info(f"[After] Num samples in train: {len(new_train_view)}")
    logger.info(f"[After] Num samples in test: {len(new_test_view)}")
    train_counts = new_train_view.count_values("detections.detections.label")
    test_counts = new_test_view.count_values("detections.detections.label")
    logger.info(f"[After] Train counts: {train_counts}")
    logger.info(f"[After] Test counts: {test_counts}")

    # export the new split
    logger.info("Exporting new train/test...")
    new_train_p = Path(in_dir) / "annotations" / f"new_train_no-{exclude_name}.json"
    new_test_p = Path(in_dir) / "annotations" / f"new_test_{exclude_name}.json"
    new_train_view.export(
        dataset_type=fo.types.COCODetectionDataset,
        labels_path=new_train_p,
        label_field="segmentations",
        classes=ds1.default_classes,
        abs_paths=True,
    )
    new_test_view.export(
        dataset_type=fo.types.COCODetectionDataset,
        labels_path=new_test_p,
        label_field="segmentations",
        classes=ds2.default_classes,
        abs_paths=True,
    )
    logger.info(f"Exported new train: {new_train_p}")
    logger.info(f"Exported new test: {new_test_p}")


def _normalize(img_size, xy_s):
    assert len(xy_s) % 2 == 0
    normalized_xy_s = []
    dw = 1.0 / (img_size[0])
    dh = 1.0 / (img_size[1])
    for i in range(len(xy_s)):
        p = xy_s[i]
        p = p * dw if i % 2 == 0 else p * dh
        assert p <= 1.0 and p >= 0.0, f"{p} should < 1 and > 0"
        normalized_xy_s.append(p)
    return normalized_xy_s


def _coco2yolo(coco_img_dir, coco_json_path, out_dir, bbox_only=False, rois=None):
    logger.info(f"Reading {Path(coco_json_path).name}...")
    coco = COCO(coco_json_path)

    cats = coco.loadCats(coco.getCatIds())
    cats = sorted(cats, key=lambda x: x["id"], reverse=False)
    assert cats[0]["id"] == 1, f"Assume cat id starts from 1, but got {cats[0]['id']}"
    logger.info(f"{len(cats)} categories: {[cat['name'] for cat in cats]}")

    img_ids = coco.getImgIds()
    prefix = Path(coco_json_path).stem.split("_")[-1].lower()  # either train or test

    # create output directories
    target_txt_r = Path(out_dir) / prefix / "labels"
    target_img_r = Path(out_dir) / prefix / "images"
    target_txt_r.mkdir(parents=True, exist_ok=False)
    target_img_r.mkdir(parents=True, exist_ok=False)

    logger.info(f"Num of imgs: {len(img_ids)}")

    n_imgs_no_annos = 0
    num_zero_area = 0
    for img_id in img_ids:
        img = coco.loadImgs(img_id)[0]
        img_p = Path(coco_img_dir) / img["file_name"]
        assert img_p.exists(), f"{img_p} does not exist"

        anno_ids = coco.getAnnIds(imgIds=img["id"])
        annos = coco.loadAnns(anno_ids)

        new_filename = f"{img['id']}_{img_p.stem}"

        out_img_p = target_img_r / (new_filename + img_p.suffix)

        # get roi for the image if any
        im_cv = cv2.imread(img_p.as_posix())
        im_width, im_height = im_cv.shape[1], im_cv.shape[0]
        roi = rois[(im_width, im_height)] if rois is not None else None
        has_roi = (rois is not None) and (roi is not None) and len(roi) == 4
        if not has_roi:
            # copy image to target dir
            shutil.copy(img_p, out_img_p)
        else:
            # crop the image to target dir
            assert len(roi) == 4, f"ROI should have 4 values, but got {roi}"
            cropped_img = im_cv[roi[1] : roi[1] + roi[3], roi[0] : roi[0] + roi[2]]
            cv2.imwrite(out_img_p.as_posix(), cropped_img)

        # bg imgs: only need to copy img, no need to create label file
        if len(annos) == 0:
            n_imgs_no_annos += 1
            continue

        # create the label txt file
        txt_p = Path(target_txt_r) / (new_filename + ".txt")
        if txt_p.exists():
            logger.warning(f"{txt_p} already exists, {img_p} skipped")
        txt_f = open(txt_p, "w")
        img = cv2.imread(img_p.as_posix())
        h, w, _ = img.shape

        # generate txt file for each image
        for ann in annos:
            cls_id = ann["category_id"] - 1  # yolov5 uses zero-based class idx

            # region bbox, for object detection
            if bbox_only:
                bbox = ann["bbox"]
                # convert coco to yolo: top-x, top-y, w, h -> center-x, center-y, w, h
                bbox_yolo = [
                    bbox[0] + bbox[2] / 2,
                    bbox[1] + bbox[3] / 2,
                    bbox[2],
                    bbox[3],
                ]
                n_bbox_p = " ".join([str(a) for a in _normalize((w, h), bbox_yolo)])
                txt_f.write(f"{cls_id} {n_bbox_p}{os.linesep}")
                continue
            # endregion

            # region seg, for instance segmentation
            seg = ann["segmentation"]
            if len(seg) > 1:
                # TODO: Investigate why sometimes there are multiple segs
                logger.warning(f"Skip {img_p} with {len(seg)} segs of {ann}")
                continue

            if len(seg) == 1:
                xy_s = seg[0]
                # handle roi if any
                if has_roi:
                    xy_s = [xy - roi[i % 2] for i, xy in enumerate(xy_s)]
                    w, h = roi[2], roi[3]
                    # remove the points outside of roi
                    new_xy_s = []
                    for i in range(0, len(xy_s), 2):
                        x, y = xy_s[i], xy_s[i + 1]
                        if x >= 0 and x <= w and y >= 0 and y <= h:
                            new_xy_s.extend([x, y])
                    xy_s = new_xy_s
                n_xy_s = _normalize((w, h), xy_s)
                seg_p = " ".join([str(a) for a in n_xy_s])
                txt_f.write(f"{cls_id} {seg_p}{os.linesep}")
            # endregion

            # region keypoint, for pose estimation
            if "keypoints" in ann:
                # skip area 0 keypoints which could cause yolov8 training error
                if int(ann["area"]) == 0:
                    num_zero_area += 1
                    continue
                kps = ann["keypoints"]
                bbox = ann["bbox"]
                # convert coco to yolo: top-x, top-y, w, h -> center-x, center-y, w, h
                bbox_yolo = [
                    bbox[0] + bbox[2] / 2,
                    bbox[1] + bbox[3] / 2,
                    bbox[2],
                    bbox[3],
                ]
                n_bbox_p = " ".join([str(a) for a in _normalize((w, h), bbox_yolo)])
                # normalize x,y of each keypoint and keep visibility as is
                n_kp = []
                for i in range(0, len(kps), 3):
                    n_kp.append(kps[i] / w)
                    n_kp.append(kps[i + 1] / h)
                    n_kp.append(kps[i + 2])
                n_kp_p = " ".join([str(a) for a in n_kp])
                txt_f.write(f"{cls_id} {n_bbox_p} {n_kp_p}{os.linesep}")
            # endregion
        txt_f.close()
        # remove empty label file which has no annos
        if txt_p.stat().st_size == 0:
            txt_p.unlink()
            n_imgs_no_annos += 1
    empty_ratio = 100 * float(n_imgs_no_annos) / len(img_ids)
    n_imgs_anns = len(img_ids) - n_imgs_no_annos
    logger.info(f"# imgs w anns: {n_imgs_anns} {(100-empty_ratio):.2f}%")
    logger.info(f"# imgs w/o anns: {n_imgs_no_annos} {empty_ratio:.2f}%")
    logger.info(f"# zero area kps: {num_zero_area}")
    txts = [f for f in target_txt_r.iterdir() if f.is_file()]
    imgs = [f for f in target_img_r.iterdir() if f.is_file()]
    assert (len(txts) + n_imgs_no_annos) == len(imgs) == len(img_ids)
    return target_img_r


@app.command(help="Convert COCO dataset to YOLOv5 format")
def coco2yolov5(
    in_dir: str,
    out_dir: str,
    split_val_ratio: float = 0.2,
    seed: int = 42,
    bbox_only: bool = False,
    crop_roi_file: str = None,
):
    """
    Convert COCO dataset to YOLOv5 format.
    Support 3 task types: object detection, instance segmentation, pose estimation.

    YOLOv5 seg labels are the same as detection labels, using txt files with one object per line.
    The difference is that instead of "class, xywh" they are "class xy1, xy2, xy3,...".
    Ref: https://github.com/ultralytics/yolov5/issues/10161#issuecomment-1315672357

    YOLOv5 keypoint labels use txt files with one object per line:
    class cx cy w h x1 y1 v1 ... xn yn vn
    All coordinates are normalized by image width and height.
    vn (visibility): 0, 1, or 2 => not labeled, labeled but invisible, labeled and visible
    Ref: https://github.com/ultralytics/ultralytics/issues/1870#issuecomment-1498909244
    Example: https://ultralytics.com/assets/coco8-pose.zip
    """
    if Path(out_dir).exists():
        delete = typer.confirm(f"{out_dir} already exists. Are you sure to delete it?")
        if not delete:
            logger.info("Not deleting")
            raise typer.Abort()
        shutil.rmtree(out_dir)
        logger.info(f"Deleted {Path(out_dir).name}")

    ann_dir_p = Path(in_dir) / "annotations"
    img_dir_p = Path(in_dir) / "images"
    assert ann_dir_p.exists(), f"{ann_dir_p} does not exist"
    assert img_dir_p.exists(), f"{img_dir_p} does not exist"

    # try to find the json files of train & test in annotations dir
    train_json_p = None
    test_json_p = None
    for f in ann_dir_p.iterdir():
        if f.stem.lower().endswith("train"):
            train_json_p = f
            logger.info(f"Found train json: {f.name}")
        elif f.stem.lower().endswith("test"):
            test_json_p = f
            logger.info(f"Found test json: {f.name}")
    # must have train, while test is optional
    assert train_json_p is not None, f"Cannot find train json in {ann_dir_p}"
    do_split = False
    if test_json_p is None:
        logger.warning("Cannot find test json in [in_dir]/annotations")
        do_split = typer.confirm("Do you want to split val from train?")

    # region handle ROIs
    rois = None
    if crop_roi_file is not None:
        roi_csv_p = Path(crop_roi_file)
        assert roi_csv_p.exists(), f"{roi_csv_p} does not exist"
        # read ROIs from csv, each image size should have one ROI
        rois = defaultdict(lambda: [], {})
        with open(roi_csv_p, "r") as f:
            for roi in csv.DictReader(f):
                ori_width = int(roi["ori_width"])
                ori_height = int(roi["ori_height"])
                roi_x = int(roi["roi_x"])
                roi_y = int(roi["roi_y"])
                roi_width = int(roi["roi_width"])
                roi_height = int(roi["roi_height"])

                key = (ori_width, ori_height)
                assert key not in rois, f"Duplicate ROI for {key}"
                rois[key] = [roi_x, roi_y, roi_width, roi_height]
    # endregion

    yolo_train_img_dir = None
    yolo_test_img_dir = None
    yolo_train_img_dir = _coco2yolo(img_dir_p, train_json_p, out_dir, bbox_only, rois)
    if test_json_p is not None:
        yolo_test_img_dir = _coco2yolo(img_dir_p, test_json_p, out_dir, bbox_only, rois)

    if do_split:
        yolo_test_img_dir = Path(out_dir) / "val" / "images"
        # randomly select 20% of train images
        train_imgs = [f for f in yolo_train_img_dir.iterdir() if f.is_file()]
        n_test = int(len(train_imgs) * split_val_ratio)
        logger.info(f"Split ratio {split_val_ratio}: {n_test} test images from train")
        # set random seed to make sure the same images are selected
        random.seed(seed)
        test_imgs = random.sample(train_imgs, n_test)
        # move test images to val/images
        yolo_test_img_dir.mkdir(parents=True, exist_ok=True)
        for f in test_imgs:
            shutil.move(str(f), str(yolo_test_img_dir))
        # move labels of test images to val/labels
        yolo_test_label_dir = Path(out_dir) / "val" / "labels"
        yolo_test_label_dir.mkdir(parents=True, exist_ok=True)
        for f in test_imgs:
            label_f = yolo_train_img_dir.parent / "labels" / f"{f.stem}.txt"
            if label_f.exists():
                shutil.move(str(label_f), str(yolo_test_label_dir))

    # region create yaml file

    logger.info(f"Reading {Path(train_json_p).name}...")
    train_coco = COCO(train_json_p)
    train_cats = train_coco.loadCats(train_coco.getCatIds())
    num_kps = [
        len(c["keypoints"])
        for c in train_cats
        if "keypoints" in c and len(c["keypoints"]) > 0
    ]
    # check if all categories have the same number of keypoints
    if len(num_kps) > 0:
        assert len(set(num_kps)) == 1, "Categories have different number of keypoints"
        logger.info(f"Number of keypoints: {set(num_kps)}")
    train_cats = [c["name"] for c in train_cats]
    # ensure having the same categories in the json of train & test
    # test_coco = COCO(test_json_p)
    # test_cats = test_coco.loadCats(test_coco.getCatIds())
    # test_cats = sorted(test_cats, key=lambda x: x["id"], reverse=False)
    # test_cats = [c["name"] for c in test_cats]
    # assert ",".join(train_cats) == ",".join(test_cats), "Categories mismatch"

    out_config_file = Path(out_dir) / "data.yaml"
    with open(out_config_file, "w") as f:
        if len(num_kps) > 0:
            f.write(f"kpt_shape: [{num_kps[0]},3]" + os.linesep)
            assert num_kps[0] == 1, "Only support 1 keypoint for now"
            f.write("flip_idx: [0]" + os.linesep)
        f.write("names:" + os.linesep)
        for c in train_cats:
            f.write(f"- {c}" + os.linesep)
        f.write(f"nc: {len(train_cats)}" + os.linesep)
        f.write(f"path: {Path(out_dir).absolute()}" + os.linesep)
        train_rel_path = f"{yolo_train_img_dir.parent.name}/{yolo_train_img_dir.name}"
        f.write(f"train: {train_rel_path}" + os.linesep)
        if yolo_test_img_dir is not None:
            val_rel_path = f"{yolo_test_img_dir.parent.name}/{yolo_test_img_dir.name}"
            f.write(f"val: {val_rel_path}" + os.linesep)

    logger.info(f"Config file saved: {out_config_file}")
    # endregion

    logger.info("Done ✅")


@app.command(help="List all image sizes and counts in a directory recursively")
def list_img_sizes(
    in_dir: str = typer.Argument(..., help="Input directory"),
):
    in_dir_p = Path(in_dir)
    assert in_dir_p.exists(), f"{in_dir_p} does not exist"
    assert in_dir_p.is_dir(), f"{in_dir_p} is not a directory"

    ds = fo.Dataset.from_images_dir(in_dir_p)
    ds.compute_metadata()

    logger.info(f"Found {len(ds)} images in {in_dir_p}")

    # count number of images for each size
    sizes = defaultdict(lambda: 0, {})
    for sample in ds:
        metadata = sample.metadata
        width = metadata.width
        height = metadata.height
        sizes[(width, height)] += 1
    # sort with the most frequent size first
    sizes = dict(sorted(sizes.items(), key=lambda x: x[1], reverse=True))
    for k, v in sizes.items():
        # find one example image for each size
        sample = ds.match({"metadata.width": k[0], "metadata.height": k[1]}).first()
        print(f"Size (w, h) {k}: {v} image(s), e.g., {sample.filepath}")


@app.command(help="Crop images in a directory recursively with ROIs from csv")
def crop_imgs(
    in_dir: str = typer.Argument(..., help="Input directory"),
    roi_csv: str = typer.Argument(..., help="CSV file containing ROIs"),
):
    in_dir_p = Path(in_dir)
    assert in_dir_p.exists(), f"{in_dir_p} does not exist"
    assert in_dir_p.is_dir(), f"{in_dir_p} is not a directory"

    roi_csv_p = Path(roi_csv)
    assert roi_csv_p.exists(), f"{roi_csv_p} does not exist"

    # read ROIs from csv, each image size should have one ROI
    rois = defaultdict(lambda: [], {})
    with open(roi_csv_p, "r") as f:
        for roi in csv.DictReader(f):
            ori_width = int(roi["ori_width"])
            ori_height = int(roi["ori_height"])
            roi_x = int(roi["roi_x"])
            roi_y = int(roi["roi_y"])
            roi_width = int(roi["roi_width"])
            roi_height = int(roi["roi_height"])

            key = (ori_width, ori_height)
            assert key not in rois, f"Duplicate ROI for {key}"
            rois[key] = [roi_x, roi_y, roi_width, roi_height]

    # read and crop images
    # write the cropped images to a new directory
    out_dir_p = in_dir_p.parent / f"{in_dir_p.name}_cropped"
    Path(out_dir_p).mkdir(parents=True, exist_ok=True)
    ds = fo.Dataset.from_images_dir(in_dir_p)
    logger.info(f"Found {len(ds)} images in {in_dir_p}")
    for sample in ds:
        img_path = sample.filepath

        # read and crop the image
        img = cv2.imread(img_path)
        width, height = img.shape[1], img.shape[0]
        roi = rois[(width, height)]
        cropped_img = img[roi[1] : roi[1] + roi[3], roi[0] : roi[0] + roi[2]]

        # keep the original folder structure
        out_img_p = out_dir_p / Path(img_path).relative_to(in_dir_p.absolute())
        # create the subfolder if not exist
        if not out_img_p.parent.exists():
            out_img_p.parent.mkdir(parents=True, exist_ok=True)
        cv2.imwrite(str(out_img_p), cropped_img)
    logger.info(f"Cropped images saved to {out_dir_p}")


@app.command(help="Count num of images without aorta annotations")
def count_n_imgs_no_aorta(
    in_coco_json_p: str = typer.Argument(..., help="Input coco json file"),
    aorta_cat_name: str = typer.Argument("aorta", help="Name of aorta category"),
):
    logger.info(f"Reading {Path(in_coco_json_p).name}...")
    assert Path(in_coco_json_p).exists(), f"{in_coco_json_p} does not exist"
    coco = COCO(in_coco_json_p)

    cats = coco.loadCats(coco.getCatIds())
    cats = sorted(cats, key=lambda x: x["id"], reverse=False)
    # find the category id of aorta
    aorta_cat_id = None
    for cat in cats:
        if cat["name"] == aorta_cat_name:
            aorta_cat_id = cat["id"]
            break
    assert aorta_cat_id is not None, f"Cannot find {aorta_cat_name} in {in_coco_json_p}"
    logger.info(f"Found {aorta_cat_name} with id {aorta_cat_id}")

    n_img_no_aorta = 0
    for img_id in coco.getImgIds():
        anno_ids = coco.getAnnIds(imgIds=img_id)
        annos = coco.loadAnns(anno_ids)
        has_aorta = False
        for anno in annos:
            if anno["category_id"] == aorta_cat_id:
                has_aorta = True
                break
        if not has_aorta:
            n_img_no_aorta += 1
    logger.info(f"Found {n_img_no_aorta} images without {aorta_cat_name}")


@app.command(help="Remove non-aorta annotations from a YOLOv5 dataset")
def keep_only_aorta_labels_in_yolo(
    in_dir: str = typer.Argument(..., help="Input label directory"),
    aorta_class_id: int = typer.Argument(0, help="Class id of aorta"),
):
    txts = list(Path(in_dir).glob("*.txt"))
    logger.info(f"Found {len(txts)} txt files in {in_dir}")
    for txt_p in txts:
        ori_lines, new_lines = [], []
        with open(txt_p, "r") as f:
            ori_lines = f.readlines()
        for line in ori_lines:
            nums = line.split(" ")
            if int(nums[0]) == aorta_class_id:
                new_lines.append(line)
        with open(txt_p, "w") as new_f:
            new_f.writelines(new_lines)


if __name__ == "__main__":
    app()
```
demo.bat
ADDED
@@ -0,0 +1,12 @@
```bat
@echo off

REM "Please change the path to your own path"
cd "C:\Users\chenp\Downloads\aorta_demo_v3"

REM "Please change the path to your own path"
call C:\ProgramData\miniconda3\Scripts\activate.bat

call conda activate echo
call python demo.py --device GPU --jobs 2

pause
```
demo.py
ADDED
@@ -0,0 +1,781 @@
1 |
+
from pathlib import Path
|
2 |
+
import sys
|
3 |
+
import time
|
4 |
+
from time import perf_counter
|
5 |
+
import argparse
|
6 |
+
from loguru import logger
|
7 |
+
import os
|
8 |
+
|
9 |
+
from predict import Model
|
10 |
+
|
11 |
+
from datetime import datetime
|
12 |
+
from scipy import signal
|
13 |
+
import plotly.graph_objects as go
|
14 |
+
import numpy as np
|
15 |
+
import io
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
import cv2
|
19 |
+
from PySide6.QtCore import Qt, QThread, Signal, Slot
|
20 |
+
from PySide6.QtGui import QImage, QPixmap
|
21 |
+
from PySide6.QtWidgets import (
|
22 |
+
QApplication,
|
23 |
+
QHBoxLayout,
|
24 |
+
QLabel,
|
25 |
+
QMainWindow,
|
26 |
+
QPushButton,
|
27 |
+
QSizePolicy,
|
28 |
+
QVBoxLayout,
|
29 |
+
QWidget,
|
30 |
+
)
|
31 |
+
|
32 |
+
# for telemed
|
33 |
+
import matplotlib.pyplot as plt
|
34 |
+
import ctypes
|
35 |
+
from ctypes import *
|
36 |
+
|
37 |
+
# 720p
|
38 |
+
video_w = 1280
|
39 |
+
video_h = 720
|
40 |
+
|
41 |
+
|
42 |
+
# Copy from detection.py from telemed sample code
|
43 |
+
class Telemed:
|
44 |
+
def __init__(self):
|
45 |
+
# starting copy from the origianl main
|
46 |
+
|
47 |
+
# Setting ultrasound size
|
48 |
+
# w = 512
|
49 |
+
# h = 512
|
50 |
+
w = 640
|
51 |
+
h = 640
|
52 |
+
|
53 |
+
# Load dll
|
54 |
+
# usgfw2 = cdll.LoadLibrary('./usgfw2wrapper_C++_sources/usgfw2wrapper/x64/Release/usgfw2wrapper.dll')
|
55 |
+
usgfw2 = cdll.LoadLibrary("./usgfw2wrapper.dll")
|
56 |
+
|
57 |
+
# Ultrasound initialize
|
58 |
+
usgfw2.on_init()
|
59 |
+
ERR = usgfw2.init_ultrasound_usgfw2()
|
60 |
+
|
61 |
+
# Check probe
|
62 |
+
if ERR == 2:
|
63 |
+
logger.error("Main Usgfw2 library object not created")
|
64 |
+
usgfw2.Close_and_release()
|
65 |
+
sys.exit()
|
66 |
+
|
67 |
+
ERR = usgfw2.find_connected_probe()
|
68 |
+
|
69 |
+
if ERR != 101:
|
70 |
+
logger.error("Probe not detected")
|
71 |
+
usgfw2.Close_and_release()
|
72 |
+
sys.exit()
|
73 |
+
|
74 |
+
ERR = usgfw2.data_view_function()
|
75 |
+
|
76 |
+
if ERR < 0:
|
77 |
+
logger.error(
|
78 |
+
"Main ultrasound scanning object for selected probe not created"
|
79 |
+
)
|
80 |
+
sys.exit()
|
81 |
+
|
82 |
+
ERR = usgfw2.mixer_control_function(0, 0, w, h, 0, 0, 0)
|
83 |
+
if ERR < 0:
|
84 |
+
logger.error("B mixer control not returned")
|
85 |
+
sys.exit()
|
86 |
+
|
87 |
+
# Probe setting
|
88 |
+
res_X = ctypes.c_float(0.0)
|
89 |
+
res_Y = ctypes.c_float(0.0)
|
90 |
+
usgfw2.get_resolution(ctypes.pointer(res_X), ctypes.pointer(res_Y))
|
91 |
+
|
92 |
+
X_axis = np.zeros(shape=(w))
|
93 |
+
Y_axis = np.zeros(shape=(h))
|
94 |
+
if w % 2 == 0:
|
95 |
+
k = 0
|
96 |
+
for i in range(-w // 2, w // 2 + 1):
|
97 |
+
if i < 0:
|
98 |
+
j = i + 0.5
|
99 |
+
X_axis[k] = j * res_X.value
|
100 |
+
k = k + 1
|
101 |
+
else:
|
102 |
+
if i > 0:
|
103 |
+
j = i - 0.5
|
104 |
+
X_axis[k] = j * res_X.value
|
105 |
+
k = k + 1
|
106 |
+
|
107 |
+
else:
|
108 |
+
for i in range(-w // 2, w // 2):
|
109 |
+
X_axis[i + w / 2 + 1] = i * res_X.value
|
110 |
+
|
111 |
+
for i in range(0, h - 1):
|
112 |
+
Y_axis[i] = i * res_Y.value
|
113 |
+
|
114 |
+
old_resolution_x = res_X.value
|
115 |
+
old_resolution_y = res_X.value
|
116 |
+
|
117 |
+
# Image setting
|
118 |
+
p_array = (ctypes.c_uint * w * h * 4)()
|
119 |
+
|
120 |
+
fig, ax = plt.subplots()
|
121 |
+
usgfw2.return_pixel_values(ctypes.pointer(p_array))
|
122 |
+
buffer_as_numpy_array = np.frombuffer(p_array, np.uint)
|
123 |
+
reshaped_array = np.reshape(buffer_as_numpy_array, (w, h, 4))
|
124 |
+
|
125 |
+
img = ax.imshow(
|
126 |
+
reshaped_array[:, :, 0:3],
|
127 |
+
cmap="gray",
|
128 |
+
vmin=0,
|
129 |
+
vmax=255,
|
130 |
+
origin="lower",
|
131 |
+
extent=[np.amin(X_axis), np.amax(X_axis), np.amax(Y_axis), np.amin(Y_axis)],
|
132 |
+
)
|
133 |
+
|
134 |
+
# starting copy from the original __int__
|
135 |
+
self.w = w
|
136 |
+
self.h = h
|
137 |
+
|
138 |
+
(
|
139 |
+
self.usgfw2,
|
140 |
+
self.p_array,
|
141 |
+
self.res_X,
|
142 |
+
self.res_Y,
|
143 |
+
self.old_resolution_x,
|
144 |
+
self.old_resolution_y,
|
145 |
+
self.X_axis,
|
146 |
+
self.Y_axis,
|
147 |
+
self.img,
|
148 |
+
) = (
|
149 |
+
usgfw2,
|
150 |
+
p_array,
|
151 |
+
res_X,
|
152 |
+
res_Y,
|
153 |
+
old_resolution_x,
|
154 |
+
old_resolution_y,
|
155 |
+
X_axis,
|
156 |
+
Y_axis,
|
157 |
+
img,
|
158 |
+
)
|
159 |
+
|
160 |
+
# return the image from telemed
|
161 |
+
def imaging(self):
|
162 |
+
self.usgfw2.return_pixel_values(ctypes.pointer(self.p_array))
|
163 |
+
buffer_as_numpy_array = np.frombuffer(self.p_array, np.uint)
|
164 |
+
reshaped_array = np.reshape(buffer_as_numpy_array, (self.w, self.h, 4))
|
165 |
+
|
166 |
+
self.usgfw2.get_resolution(
|
167 |
+
ctypes.pointer(self.res_X), ctypes.pointer(self.res_Y)
|
168 |
+
)
|
169 |
+
if (
|
170 |
+
self.res_X.value != self.old_resolution_x
|
171 |
+
or self.res_Y.value != self.old_resolution_y
|
172 |
+
):
|
173 |
+
if self.w % 2 == 0:
|
174 |
+
k = 0
|
175 |
+
for i in range(-self.w // 2, self.w // 2 + 1):
|
176 |
+
if i < 0:
|
177 |
+
j = i + 0.5
|
178 |
+
self.X_axis[k] = j * self.res_X.value
|
179 |
+
k = k + 1
|
180 |
+
else:
|
181 |
+
if i > 0:
|
182 |
+
j = i - 0.5
|
183 |
+
self.X_axis[k] = j * self.res_X.value
|
184 |
+
k = k + 1
|
185 |
+
else:
|
186 |
+
for i in range(-self.w // 2, self.w // 2):
|
187 |
+
self.X_axis[i + self.w / 2 + 1] = i * self.res_X.value
|
188 |
+
|
189 |
+
for i in range(0, self.h - 1):
|
190 |
+
self.Y_axis[i] = i * self.res_Y.value
|
191 |
+
|
192 |
+
self.old_resolution_x = self.res_X.value
|
193 |
+
self.old_resolution_y = self.res_X.value
|
194 |
+
|
195 |
+
self.img.set_data(reshaped_array[:, :, 0:3])
|
196 |
+
self.img.set_extent(
|
197 |
+
[
|
198 |
+
np.amin(self.X_axis),
|
199 |
+
np.amax(self.X_axis),
|
200 |
+
np.amax(self.Y_axis),
|
201 |
+
np.amin(self.Y_axis),
|
202 |
+
]
|
203 |
+
)
|
204 |
+
|
205 |
+
# Transfer image format to cv2
|
206 |
+
img_array = np.asarray(self.img.get_array())
|
207 |
+
img_array = img_array[::-1, :, ::-1] # format same as plt image, RBG to BGR
|
208 |
+
return img_array
|
209 |
+
|
210 |
+
|
211 |
+
class Thread(QThread):
|
212 |
+
updateFrame = Signal(QImage)
|
213 |
+
|
214 |
+
def __init__(self, parent=None, args=None):
|
215 |
+
QThread.__init__(self, parent)
|
216 |
+
self.status = True
|
217 |
+
self.cap = True
|
218 |
+
self.args = args
|
219 |
+
|
220 |
+
# init telemed
|
221 |
+
if args.video is None:
|
222 |
+
self.telemed = Telemed()
|
223 |
+
|
224 |
+
# init model
|
225 |
+
is_async = (
|
226 |
+
True if self.args.jobs == "auto" or int(self.args.jobs) > 1 else False
|
227 |
+
)
|
228 |
+
self.model = Model(
|
229 |
+
model_path=self.args.model,
|
230 |
+
imgsz=self.args.img_size,
|
231 |
+
classes=self.args.classes,
|
232 |
+
device=self.args.device,
|
233 |
+
plot_mask=self.args.plot_mask,
|
234 |
+
conf_thres=self.args.conf_thres,
|
235 |
+
is_async=is_async,
|
236 |
+
n_jobs=self.args.jobs,
|
237 |
+
)
|
238 |
+
|
239 |
+
def get_stats_fig(self, aorta_widths, aorta_confs, fig_w, fig_h, ts):
|
240 |
+
title_font_size = 28
|
241 |
+
body_font_size = 24
|
242 |
+
img_quality = 100 * np.mean(aorta_confs)
|
243 |
+
avg_width = np.mean(aorta_widths)
|
244 |
+
max_width = np.max(aorta_widths)
|
245 |
+
suggestions = [
|
246 |
+
"N/A, within normal limit",
|
247 |
+
"Follow up in 5 years",
|
248 |
+
"Make an appointment as soon as possible",
|
249 |
+
]
|
250 |
+
s = None
|
251 |
+
if avg_width < 3:
|
252 |
+
s = suggestions[0]
|
253 |
+
elif avg_width < 5:
|
254 |
+
s = suggestions[1]
|
255 |
+
else:
|
256 |
+
s = suggestions[2]
|
257 |
+
|
258 |
+
# region smoothing: method 2, keep the peaks
|
259 |
+
# peaks = signal.find_peaks(aorta_widths, height=0.5, distance=40)
|
260 |
+
# new_y = []
|
261 |
+
# # smooth the values between the peaks
|
262 |
+
# start = 0
|
263 |
+
# end = peaks[0][0]
|
264 |
+
# new_y.extend(signal.savgol_filter(aorta_widths[start:end], end - start, 2))
|
265 |
+
# for i in range(len(peaks[0]) - 1):
|
266 |
+
# start = peaks[0][i] + 1
|
267 |
+
# end = peaks[0][i + 1]
|
268 |
+
# new_y.append(aorta_widths[peaks[0][i]]) # add peak value
|
269 |
+
# new_y.extend(
|
270 |
+
# signal.savgol_filter(
|
271 |
+
# aorta_widths[start:end],
|
272 |
+
# end - start, # window size used for filtering
|
273 |
+
# 2,
|
274 |
+
# )
|
275 |
+
# ) # order of fitted polynomial
|
276 |
+
# # add the last peak
|
277 |
+
# new_y.append(aorta_widths[peaks[0][-1]])
|
278 |
+
# start = peaks[0][-1] + 1
|
279 |
+
# end = len(aorta_widths)
|
280 |
+
# new_y.extend(signal.savgol_filter(aorta_widths[start:end], end - start, 2))
|
281 |
+
# endregion
|
282 |
+
|
283 |
+
# region smoothing: method 1, do not keep the peaks
|
284 |
+
window_size = 53
|
285 |
+
if len(aorta_widths) < window_size:
|
286 |
+
window_size = len(aorta_widths) - 1
|
287 |
+
new_y = signal.savgol_filter(aorta_widths, window_size, 3)
|
288 |
+
# endregion
|
289 |
+
|
290 |
+
x = np.arange(1, len(aorta_widths) + 1, dtype=int)
|
291 |
+
|
292 |
+
fig = go.Figure()
|
293 |
+
fig.add_trace(
|
294 |
+
go.Scatter(
|
295 |
+
x=x, y=aorta_widths, mode="lines", line=dict(color="royalblue", width=1)
|
296 |
+
)
|
297 |
+
)
|
298 |
+
fig.add_trace(
|
299 |
+
go.Scatter(
|
300 |
+
x=x,
|
301 |
+
y=new_y,
|
302 |
+
mode="lines",
|
303 |
+
marker=dict(
|
304 |
+
size=3,
|
305 |
+
color="mediumpurple",
|
306 |
+
),
|
307 |
+
)
|
308 |
+
)
|
309 |
+
fig.update_layout(
|
310 |
+
autosize=False,
|
311 |
+
width=fig_w,
|
312 |
+
height=fig_h,
|
313 |
+
margin=dict(l=50, r=50, b=50, t=400, pad=4),
|
314 |
+
paper_bgcolor="LightSteelBlue",
|
315 |
+
showlegend=False,
|
316 |
+
)
|
317 |
+
fig.add_annotation(
|
318 |
+
text=f"max={max_width:.2f} cm",
|
319 |
+
x=np.argmax(aorta_widths),
|
320 |
+
y=np.max(aorta_widths),
|
321 |
+
xref="x",
|
322 |
+
yref="y",
|
323 |
+
showarrow=True,
|
324 |
+
font=dict(color="#ffffff"),
|
325 |
+
arrowhead=2,
|
326 |
+
arrowsize=1,
|
327 |
+
arrowwidth=2,
|
328 |
+
borderpad=4,
|
329 |
+
bgcolor="#ff7f0e",
|
330 |
+
opacity=0.8,
|
331 |
+
)
|
332 |
+
fig.add_annotation(
|
333 |
+
text=f"smoothed max={np.max(new_y):.2f} cm",
|
334 |
+
x=np.argmax(new_y),
|
335 |
+
y=np.max(new_y),
|
336 |
+
xref="x",
|
337 |
+
yref="y",
|
338 |
+
showarrow=True,
|
339 |
+
font=dict(color="#ffffff"),
|
340 |
+
arrowhead=2,
|
341 |
+
arrowsize=1,
|
342 |
+
arrowwidth=2,
|
343 |
+
ax=-100,
|
344 |
+
ay=-50,
|
345 |
+
borderpad=4,
|
346 |
+
bgcolor="#ff7f0e",
|
347 |
+
opacity=0.8,
|
348 |
+
)
|
349 |
+
fig.add_annotation(
|
350 |
+
text="<b>Report of Abdominal Aorta Examination</b>",
|
351 |
+
xref="paper",
|
352 |
+
yref="paper",
|
353 |
+
x=0.5,
|
354 |
+
y=2.3,
|
355 |
+
showarrow=False,
|
356 |
+
font=dict(size=title_font_size),
|
357 |
+
)
|
358 |
+
fig.add_annotation(
|
359 |
+
text=f"Image acquisition quality: {img_quality:.0f}%",
|
360 |
+
xref="paper",
|
361 |
+
yref="paper",
|
362 |
+
x=0,
|
363 |
+
y=2.0,
|
364 |
+
showarrow=False,
|
365 |
+
font=dict(size=body_font_size),
|
366 |
+
)
|
367 |
+
fig.add_annotation(
|
368 |
+
text=f"Aorta Maximal Width: {max_width:.2f} cm",
|
369 |
+
xref="paper",
|
370 |
+
yref="paper",
|
371 |
+
x=0,
|
372 |
+
y=1.8,
|
373 |
+
showarrow=False,
|
374 |
+
font=dict(size=body_font_size),
|
375 |
+
)
|
376 |
+
fig.add_annotation(
|
377 |
+
text=f"Aorta Maximal Width (Smoothed): {np.max(new_y):.2f} cm",
|
378 |
+
xref="paper",
|
379 |
+
yref="paper",
|
380 |
+
x=0,
|
381 |
+
y=1.6,
|
382 |
+
showarrow=False,
|
383 |
+
font=dict(size=body_font_size),
|
384 |
+
)
|
385 |
+
fig.add_annotation(
|
386 |
+
text=f"Average: {avg_width:.2f} cm",
|
387 |
+
xref="paper",
|
388 |
+
yref="paper",
|
389 |
+
x=0,
|
390 |
+
y=1.4,
|
391 |
+
showarrow=False,
|
392 |
+
font=dict(size=body_font_size),
|
393 |
+
)
|
394 |
+
fig.add_annotation(
|
395 |
+
text=f"Suggestion: {s}",
|
396 |
+
xref="paper",
|
397 |
+
yref="paper",
|
398 |
+
x=0,
|
399 |
+
y=1.2,
|
400 |
+
showarrow=False,
|
401 |
+
font=dict(size=body_font_size),
|
402 |
+
)
|
403 |
+
fig.add_annotation(
|
404 |
+
text=f"Generated at {ts}",
|
405 |
+
xref="paper",
|
406 |
+
yref="paper",
|
407 |
+
x=1,
|
408 |
+
y=1,
|
409 |
+
showarrow=False,
|
410 |
+
)
|
411 |
+
return fig
|
412 |
+
|
413 |
+
def run(self):
|
414 |
+
one_cm_in_pixels = 48 # hard-coded
|
415 |
+
aorta_cm_thre1 = 3
|
416 |
+
aorta_cm_thre2 = 5
|
417 |
+
black = (0, 0, 0)
|
418 |
+
white = (255, 255, 255)
|
419 |
+
red = (0, 0, 255)
|
420 |
+
green = (0, 255, 0)
|
421 |
+
|
422 |
+
aorta_widths_stats = [0, 0, 0] # three ranges: <3, 3-5, >5
|
423 |
+
aorta_widths = []
|
424 |
+
aorta_confs = []
|
425 |
+
|
426 |
+
expected_fps = None
|
427 |
+
frame_count = None
|
428 |
+
frame_w = None
|
429 |
+
frame_h = None
|
430 |
+
if self.args.video:
|
431 |
+
self.cap = cv2.VideoCapture(self.args.video)
|
432 |
+
expected_fps = self.cap.get(cv2.CAP_PROP_FPS)
|
433 |
+
secs_per_frame = 1 / expected_fps
|
434 |
+
frame_w, frame_h = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(
|
435 |
+
self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
|
436 |
+
)
|
437 |
+
frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
438 |
+
logger.info(f"Video source FPS: {expected_fps}")
|
439 |
+
logger.info(f"Milliseconds per frame: {secs_per_frame}")
|
440 |
+
logger.info(f"Video source resolution (WxH): {frame_w}x{frame_h}")
|
441 |
+
logger.info(f"Video source frame count: {frame_count}")
|
442 |
+
assert frame_count > 0, "No frame found"
|
443 |
+
|
444 |
+
n_read_frames = 0
|
445 |
+
next_frame_to_infer = 0
|
446 |
+
next_frame_to_show = 0
|
447 |
+
n_repeat_failure = 0
|
448 |
+
is_last_failed = False
|
449 |
+
start_time = perf_counter()
|
450 |
+
while self.status:
|
451 |
+
frame = None
|
452 |
+
|
453 |
+
# avoid infinite loop
|
454 |
+
if n_repeat_failure > 30:
|
455 |
+
break
|
456 |
+
|
457 |
+
# inference
|
458 |
+
color_frame, others, results, xyxy, conf = None, None, None, None, None
|
459 |
+
if self.model.is_async:
|
460 |
+
results = self.model.get_result(next_frame_to_show)
|
461 |
+
if results:
|
462 |
+
color_frame, others = results
|
463 |
+
xyxy, conf, _ = others
|
464 |
+
next_frame_to_show += 1
|
465 |
+
|
466 |
+
if self.model.is_async and self.model.is_free_to_infer_async():
|
467 |
+
if self.args.video:
|
468 |
+
ret, frame = self.cap.read()
|
469 |
+
|
470 |
+
if not ret:
|
471 |
+
n_repeat_failure += 1 if is_last_failed else 0
|
472 |
+
is_last_failed = True
|
473 |
+
continue
|
474 |
+
else:
|
475 |
+
# read the frame from telemed
|
476 |
+
# TODO(martin): Check read failure
|
477 |
+
frame = self.telemed.imaging()
|
478 |
+
|
479 |
+
n_read_frames += 1
|
480 |
+
self.model.predict_async(frame, next_frame_to_infer)
|
481 |
+
next_frame_to_infer += 1
|
482 |
+
elif not self.model.is_async:
|
483 |
+
if self.args.video:
|
484 |
+
ret, frame = self.cap.read()
|
485 |
+
if not ret:
|
486 |
+
n_repeat_failure += 1 if is_last_failed else 0
|
487 |
+
is_last_failed = True
|
488 |
+
continue
|
489 |
+
else:
|
490 |
+
# read the frame from telemed
|
491 |
+
# TODO(martin): Check read failure
|
492 |
+
frame = self.telemed.imaging()
|
493 |
+
|
494 |
+
n_read_frames += 1
|
495 |
+
results = self.model.predict(frame)
# guard against a failed prediction before unpacking the results
if results is None:
    continue
color_frame, others = results
xyxy, conf, _ = others  # bbox and confidence
|
500 |
+
|
501 |
+
is_last_failed = False
|
502 |
+
|
503 |
+
# check if aorta is within the ROI box, and draw the box
|
504 |
+
aorta_width_in_cm = 0
|
505 |
+
is_found = xyxy is not None
|
506 |
+
is_in_box = False
|
507 |
+
is_too_left, is_too_right = False, False
|
508 |
+
w, h = color_frame.shape[1], color_frame.shape[0]
|
509 |
+
box_w = int(w * 0.1)
|
510 |
+
box_h = int(h * 0.5)
|
511 |
+
box_top_left = (w // 2 - box_w // 2, h // 4)
|
512 |
+
box_bottom_right = (w // 2 + box_w // 2, h // 4 + box_h)
|
513 |
+
if xyxy is not None:
|
514 |
+
x1, y1, x2, y2 = xyxy
|
515 |
+
|
516 |
+
# check aorta width
|
517 |
+
aorta_width_in_cm = (x2 - x1) / one_cm_in_pixels
|
518 |
+
aorta_widths.append(aorta_width_in_cm)
|
519 |
+
aorta_confs.append(conf)
|
520 |
+
if aorta_width_in_cm < aorta_cm_thre1:
|
521 |
+
aorta_widths_stats[0] += 1
|
522 |
+
elif aorta_width_in_cm < aorta_cm_thre2:
|
523 |
+
aorta_widths_stats[1] += 1
|
524 |
+
else:
|
525 |
+
aorta_widths_stats[2] += 1
|
526 |
+
|
527 |
+
# check whether aorta is in the box
|
528 |
+
if (
|
529 |
+
x1 > box_top_left[0]
|
530 |
+
and x2 < box_bottom_right[0]
|
531 |
+
and y1 > box_top_left[1]
|
532 |
+
and y2 < box_bottom_right[1]
|
533 |
+
):
|
534 |
+
is_in_box = True
|
535 |
+
is_too_right = x2 > box_bottom_right[0]
|
536 |
+
is_too_left = x1 < box_top_left[0]
|
537 |
+
|
538 |
+
# plot ROI box with color status
|
539 |
+
box_color = green if is_in_box else red
|
540 |
+
color_frame = cv2.rectangle(
|
541 |
+
color_frame, box_top_left, box_bottom_right, box_color, 2
|
542 |
+
)
|
543 |
+
assert not (
|
544 |
+
is_too_left and is_too_right
|
545 |
+
), "Cannot be both too left and too right"
|
546 |
+
if is_too_left:
|
547 |
+
start_p = (box_top_left[0], int(h * 0.9))
|
548 |
+
end_p = (box_bottom_right[0], int(h * 0.9))
|
549 |
+
cv2.arrowedLine(color_frame, start_p, end_p, red, 3)
|
550 |
+
if is_too_right:
|
551 |
+
start_p = (box_bottom_right[0], int(h * 0.9))
|
552 |
+
end_p = (box_top_left[0], int(h * 0.9))
|
553 |
+
cv2.arrowedLine(color_frame, start_p, end_p, red, 3)
|
554 |
+
if is_in_box:
|
555 |
+
cv2.putText(
|
556 |
+
color_frame,
|
557 |
+
"GOOD",
|
558 |
+
(box_top_left[0], int(h * 0.9)),
|
559 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
560 |
+
1,
|
561 |
+
green,
|
562 |
+
3,
|
563 |
+
)
|
564 |
+
|
565 |
+
# plot aorta width
|
566 |
+
text = (
|
567 |
+
f"Aorta width: {aorta_width_in_cm:.2f} cm"
|
568 |
+
if is_found
|
569 |
+
else "Aorta width: N/A"
|
570 |
+
)
|
571 |
+
cv2.putText(
|
572 |
+
color_frame, text, (50, 90), cv2.FONT_HERSHEY_SIMPLEX, 1, white, 3
|
573 |
+
)
|
574 |
+
|
575 |
+
# region FPS
|
576 |
+
fps = None
|
577 |
+
if n_read_frames > 0:
|
578 |
+
fps = n_read_frames / (perf_counter() - start_time)
|
579 |
+
|
580 |
+
# Slow down the loop if FPS is too high
|
581 |
+
if self.args.sync:
|
582 |
+
while fps > expected_fps:
|
583 |
+
time.sleep(0.001)
|
584 |
+
fps = n_read_frames / (perf_counter() - start_time)
|
585 |
+
|
586 |
+
cv2.putText(
|
587 |
+
color_frame,
|
588 |
+
f"FPS: {fps:.2f}",
|
589 |
+
(50, 30),
|
590 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
591 |
+
1,
|
592 |
+
white,
|
593 |
+
3,
|
594 |
+
)
|
595 |
+
# endregion
|
596 |
+
|
597 |
+
# Creating and scaling QImage
|
598 |
+
h, w, ch = color_frame.shape
|
599 |
+
img = QImage(color_frame.data, w, h, ch * w, QImage.Format_BGR888)
|
600 |
+
scaled_img = img.scaled(video_w, video_h, Qt.KeepAspectRatio)
|
601 |
+
|
602 |
+
# Emit signal
|
603 |
+
self.updateFrame.emit(scaled_img)
|
604 |
+
|
605 |
+
if self.args.video:
|
606 |
+
progress = 100 * n_read_frames / frame_count
|
607 |
+
fps_msg = f", FPS: {fps:.2f}" if fps is not None else ""
|
608 |
+
print(
|
609 |
+
f"Processed {n_read_frames}/{frame_count} ({progress:.2f}%) frames"
|
610 |
+
+ fps_msg,
|
611 |
+
end="\r" if n_read_frames < frame_count else os.linesep,
|
612 |
+
)
|
613 |
+
if n_read_frames >= frame_count:
|
614 |
+
logger.info("Finished processing video")
|
615 |
+
break
|
616 |
+
if self.args.video:
|
617 |
+
self.cap.release()
|
618 |
+
|
619 |
+
if not self.status:
|
620 |
+
logger.info("Stopped by user")
|
621 |
+
return
|
622 |
+
|
623 |
+
# draw a black image with frame width & height
|
624 |
+
# with some text in center indicating generating report
|
625 |
+
# it's just a dummy step to make demo more real
|
626 |
+
im = np.zeros((frame_h, frame_w, 3), np.uint8)
|
627 |
+
cv2.putText(
|
628 |
+
im,
|
629 |
+
"Generating report for you...",
|
630 |
+
(frame_w // 3, frame_h // 2),
|
631 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
632 |
+
1,
|
633 |
+
white,
|
634 |
+
3,
|
635 |
+
)
|
636 |
+
img = QImage(im.data, frame_w, frame_h, 3 * frame_w, QImage.Format_BGR888)  # bytes per line must match this dummy frame, not the last camera frame
|
637 |
+
scaled_img = img.scaled(video_w, video_h, Qt.KeepAspectRatio)
|
638 |
+
self.updateFrame.emit(scaled_img)
|
639 |
+
time.sleep(3)
|
640 |
+
|
641 |
+
# plot aorta width tracing line chart
|
642 |
+
now_t = datetime.now()
|
643 |
+
ts1 = now_t.strftime("%Y%m%d_%H%M%S")
|
644 |
+
ts2 = now_t.strftime("%Y/%m/%d %I:%M:%S")
|
645 |
+
Path("runs").mkdir(parents=True, exist_ok=True)
|
646 |
+
# np.save("runs/aorta_widths.npy", aorta_widths)
|
647 |
+
fig_out_p = f"runs/aorta_report_{ts1}.jpeg"
|
648 |
+
fig = self.get_stats_fig(aorta_widths, aorta_confs, video_w, video_h, ts2)
|
649 |
+
|
650 |
+
# This may hang under Windows: https://github.com/plotly/Kaleido/issues/110
|
651 |
+
# The workaround is to install older kaleido version (see requirements.txt)
|
652 |
+
fig.write_image(fig_out_p)
|
653 |
+
|
654 |
+
logger.info(f"Saved aorta report: {fig_out_p}")
|
655 |
+
img_bytes = fig.to_image(format="jpg", width=video_w, height=video_h)
|
656 |
+
line_chart = np.array(Image.open(io.BytesIO(img_bytes)))
|
657 |
+
line_chart = cv2.cvtColor(line_chart, cv2.COLOR_RGB2BGR)
|
658 |
+
h, w, ch = line_chart.shape
|
659 |
+
img = QImage(line_chart.data, video_w, video_h, ch * w, QImage.Format_BGR888)
|
660 |
+
scaled_img = img.scaled(w, h, Qt.KeepAspectRatio)
|
661 |
+
# Emit signal
|
662 |
+
self.updateFrame.emit(scaled_img)
|
663 |
+
time.sleep(5)
|
664 |
+
|
665 |
+
# keep report open until user closes the window
|
666 |
+
while self.status and not self.args.exit_on_end:
|
667 |
+
time.sleep(0.1)
|
668 |
+
|
669 |
+
|
670 |
+
class Window(QMainWindow):
|
671 |
+
def __init__(self, args=None):
|
672 |
+
super().__init__()
|
673 |
+
# Title and dimensions
|
674 |
+
self.setWindowTitle("Demo")
|
675 |
+
self.setGeometry(0, 0, 800, 500)
|
676 |
+
|
677 |
+
# Create a label for the display camera
|
678 |
+
self.label = QLabel(self)
|
679 |
+
# self.label.setFixedSize(self.width(), self.height())
|
680 |
+
self.label.setFixedSize(video_w, video_h)
|
681 |
+
|
682 |
+
# Thread in charge of updating the image
|
683 |
+
self.th = Thread(self, args)
|
684 |
+
self.th.finished.connect(self.close)
|
685 |
+
self.th.updateFrame.connect(self.setImage)
|
686 |
+
|
687 |
+
# Buttons layout
|
688 |
+
buttons_layout = QHBoxLayout()
|
689 |
+
self.button1 = QPushButton("Start")
|
690 |
+
self.button2 = QPushButton("Stop/Close")
|
691 |
+
self.button1.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding)
|
692 |
+
self.button2.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding)
|
693 |
+
buttons_layout.addWidget(self.button2)
|
694 |
+
buttons_layout.addWidget(self.button1)
|
695 |
+
|
696 |
+
right_layout = QHBoxLayout()
|
697 |
+
# right_layout.addWidget(self.group_model, 1)
|
698 |
+
right_layout.addLayout(buttons_layout, 1)
|
699 |
+
|
700 |
+
# Main layout
|
701 |
+
layout = QVBoxLayout()
|
702 |
+
layout.addWidget(self.label)
|
703 |
+
layout.addLayout(right_layout)
|
704 |
+
|
705 |
+
# Central widget
|
706 |
+
widget = QWidget(self)
|
707 |
+
widget.setLayout(layout)
|
708 |
+
self.setCentralWidget(widget)
|
709 |
+
|
710 |
+
# Connections
|
711 |
+
self.button1.clicked.connect(self.start)
|
712 |
+
self.button2.clicked.connect(self.kill_thread)
|
713 |
+
self.button2.setEnabled(False)
|
714 |
+
|
715 |
+
if args.start_on_open:
|
716 |
+
# start thread
|
717 |
+
self.start()
|
718 |
+
|
719 |
+
@Slot()
|
720 |
+
def kill_thread(self):
|
721 |
+
logger.info("Finishing...")
|
722 |
+
self.th.status = False
|
723 |
+
time.sleep(1)
|
724 |
+
# Give time for the thread to finish
|
725 |
+
self.button2.setEnabled(False)
|
726 |
+
self.button1.setEnabled(True)
|
727 |
+
cv2.destroyAllWindows()
|
728 |
+
self.th.exit()
|
729 |
+
# Give time for the thread to finish
|
730 |
+
time.sleep(1)
|
731 |
+
|
732 |
+
@Slot()
|
733 |
+
def start(self):
|
734 |
+
logger.info("Starting...")
|
735 |
+
self.button2.setEnabled(True)
|
736 |
+
self.button1.setEnabled(False)
|
737 |
+
self.th.start()
|
738 |
+
logger.info("Thread started")
|
739 |
+
|
740 |
+
@Slot(QImage)
|
741 |
+
def setImage(self, image):
|
742 |
+
self.label.setPixmap(QPixmap.fromImage(image))
|
743 |
+
|
744 |
+
|
745 |
+
if __name__ == "__main__":
    # get user inputs using argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--video",
        type=str,
        default=None,
        help="path to video file, if None (default) would read from telemed",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="best_openvino_model/best.xml",
        help="path to model file",
    )
    parser.add_argument("--img-size", type=int, default=640, help="image size")
    parser.add_argument(
        "--classes", nargs="+", type=int, default=[0], help="filter by class"
    )
    parser.add_argument("--device", type=str, default="CPU", help="device to use")
    parser.add_argument("--sync", action="store_true", help="sync video FPS")
    parser.add_argument("--plot-mask", action="store_true", help="plot mask")
    parser.add_argument("--conf-thres", type=float, default=0.25, help="conf thresh")
    parser.add_argument("--jobs", type=str, default=1, help="num of jobs, async if > 1")
    parser.add_argument("--start-on-open", action="store_true", help="start on open")
    parser.add_argument("--exit-on-end", action="store_true", help="exit if video ends")
    args = parser.parse_args()
    assert (
        args.jobs == "auto" or int(args.jobs) > 0
    ), f"--jobs must be > 0 or auto, got {args.jobs}"
    if args.video:
        assert Path(args.video).exists(), f"Video file {args.video} not found"
    assert Path(args.model).exists(), f"Model file {args.model} not found"
    app = QApplication()
    w = Window(args)
    w.show()
    sys.exit(app.exec())
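For reference, a minimal way to launch the demo against a recorded clip might look like the sketch below. The video path is a placeholder and the model path is simply the argparse default shown above; omitting `--video` makes the demo read frames from the Telemed probe instead.

```shell
# Placeholder video path; the model path is the default from the argparse setup above
python demo.py \
    --video data/sample_aorta_clip.avi \
    --model best_openvino_model/best.xml \
    --img-size 640 --conf-thres 0.25 \
    --sync --start-on-open --exit-on-end
```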
demo_headless.sh
ADDED
@@ -0,0 +1,27 @@
#!/bin/bash

VIRTUAL_DISPLAY_NUM=99

OUTPUT_VIDEO="runs/demo_recording_$(date +"%Y-%m-%d_%H-%M-%S").mp4"

# start xvfb server
Xvfb :$VIRTUAL_DISPLAY_NUM -screen 0 1280x720x24 > /dev/null & XVFB_PID=$!

# start recording
ffmpeg -f x11grab -draw_mouse 0 -video_size 1280x720 \
    -i :$VIRTUAL_DISPLAY_NUM \
    -codec:v libx264 -r 25 $OUTPUT_VIDEO \
    > /dev/null 2>&1 < /dev/null & FFMPEG_PID=$!

# start the demo program
DISPLAY=:$VIRTUAL_DISPLAY_NUM QT_QPA_PLATFORM=xcb \
    python demo.py "$@" --start-on-open --exit-on-end

# kill the recording
kill $FFMPEG_PID

# kill xvfb server
kill $XVFB_PID

# success msg
echo -e "Recording saved: $OUTPUT_VIDEO"
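A possible invocation of the headless recording script, assuming Xvfb and ffmpeg are installed on the host; everything after the script name is forwarded to `demo.py`:

```shell
# Assumes Xvfb and ffmpeg are available; extra args are passed through to demo.py
bash demo_headless.sh --video data/sample_aorta_clip.avi --model best_openvino_model/best.xml
```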
eval.py
ADDED
@@ -0,0 +1,397 @@
1 |
+
import typer
|
2 |
+
from typing import Optional
|
3 |
+
from pathlib import Path
|
4 |
+
from loguru import logger
|
5 |
+
import cv2
|
6 |
+
from tqdm import tqdm
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
import shutil
|
10 |
+
from datetime import datetime
|
11 |
+
import matplotlib
|
12 |
+
import os
|
13 |
+
|
14 |
+
matplotlib.use("Agg") # use non-interactive backend
|
15 |
+
import matplotlib.pyplot as plt
|
16 |
+
|
17 |
+
from predict import Model
|
18 |
+
|
19 |
+
app = typer.Typer()
|
20 |
+
|
21 |
+
|
22 |
+
@app.command(help="Export videos to images (to a dir per video)")
|
23 |
+
def export_videos_to_images(
|
24 |
+
input_dir: Path = typer.Argument(..., help="Input directory"),
|
25 |
+
output_dir: Path = typer.Argument(..., help="Output directory"),
|
26 |
+
ext: str = typer.Option("avi", help="Video Extension"),
|
27 |
+
path_filter: Optional[str] = typer.Option(None, help="input path filter"),
|
28 |
+
patient_prefix: Optional[bool] = typer.Option(
|
29 |
+
True, help="use patient info as output dir prefix"
|
30 |
+
),
|
31 |
+
copy_extent: Optional[bool] = typer.Option(
|
32 |
+
True, help="copy extent files to output dir"
|
33 |
+
),
|
34 |
+
):
|
35 |
+
# log all the arguments passed in
|
36 |
+
logger.info(f"Function called with arguments: {locals()}")
|
37 |
+
|
38 |
+
# find all video files in input_dir
|
39 |
+
input_dir = Path(input_dir)
|
40 |
+
output_dir = Path(output_dir)
|
41 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
42 |
+
video_files = list(input_dir.glob(f"**/*.{ext.lower()}"))
|
43 |
+
video_files.extend(list(input_dir.glob(f"**/*.{ext.upper()}")))
|
44 |
+
logger.info(f"# of avi videos found: {len(video_files)}")
|
45 |
+
if path_filter is not None:
|
46 |
+
logger.info(f"Filtering videos with {path_filter}")
|
47 |
+
video_files = [x for x in video_files if path_filter in str(x)]
|
48 |
+
logger.info(f"# of avi videos found after filtering: {len(video_files)}")
|
49 |
+
|
50 |
+
video_files.sort(key=lambda x: x.name) # sort by name ascending
|
51 |
+
# log each video path after filtering, one per line
|
52 |
+
logger.info(f"{os.linesep}" + f"{os.linesep}".join([str(x) for x in video_files]))
|
53 |
+
|
54 |
+
# check that all the extent files exist
|
55 |
+
# the extent (.csv) should be in the same directory as the video
|
56 |
+
# the video filename would start with video_
|
57 |
+
# the extent filename would start with extents_
|
58 |
+
if copy_extent:
|
59 |
+
all_exist = True
|
60 |
+
for video_path in video_files:
|
61 |
+
extent_filename = video_path.stem.replace("video_", "extents_")
|
62 |
+
extent_path = video_path.parent / f"{extent_filename}.csv"
|
63 |
+
if not extent_path.exists():
|
64 |
+
logger.error(f"Extent file {extent_path} does not exist")
|
65 |
+
all_exist = False
|
66 |
+
if not all_exist:
|
67 |
+
logger.error("Extent files do not exist for all videos")
|
68 |
+
return
|
69 |
+
|
70 |
+
for video_path in video_files:
|
71 |
+
# copy extent file to output dir
|
72 |
+
if copy_extent:
|
73 |
+
extent_filename = video_path.stem.replace("video_", "extents_")
|
74 |
+
extent_path = video_path.parent / f"{extent_filename}.csv"
|
75 |
+
shutil.copy(extent_path, output_dir)
|
76 |
+
|
77 |
+
# Dir structure: Patient_Info / [PATIENT_ID] / [DATE] / video / xxx.avi
|
78 |
+
patient_id = (
|
79 |
+
video_path.parent.parent.parent.name
|
80 |
+
) # WARNING: Hard-coded based on dir structure
|
81 |
+
|
82 |
+
video_name = video_path.stem
|
83 |
+
logger.info(f"Processing video {video_name} of patient {patient_id}")
|
84 |
+
|
85 |
+
# create subdirectory for each video
|
86 |
+
sub_dir = output_dir / (
|
87 |
+
f"{patient_id}-{video_name}" if patient_prefix else video_name
|
88 |
+
)
|
89 |
+
sub_dir.mkdir(parents=True, exist_ok=True)
|
90 |
+
|
91 |
+
# read video and export frames
|
92 |
+
cap = cv2.VideoCapture(str(video_path))
|
93 |
+
frame_count = 0
|
94 |
+
while cap.isOpened():
|
95 |
+
ret, frame = cap.read()
|
96 |
+
if ret:
|
97 |
+
# padding frame_count with zeros
|
98 |
+
cv2.imwrite(str(sub_dir / f"{frame_count:03}.jpg"), frame)
|
99 |
+
frame_count += 1
|
100 |
+
else:
|
101 |
+
break
|
102 |
+
|
103 |
+
|
104 |
+
@app.command(help="Evaluate model on a directory of images")
|
105 |
+
def eval(
|
106 |
+
input_dir: Path = typer.Argument(..., help="Input directory"),
|
107 |
+
input_model: Path = typer.Argument(..., help="Input model"),
|
108 |
+
imgsz: int = typer.Option(640, help="Image size"),
|
109 |
+
class_id: int = typer.Option(0, help="Class id to filter"),
|
110 |
+
conf_thresh: float = typer.Option(0.5, help="Confidence threshold"),
|
111 |
+
video_ext: str = typer.Option("avi", help="Video Extension"),
|
112 |
+
out_dir: Path = typer.Option("runs", help="Output directory"),
|
113 |
+
gt_csv_path: Path = typer.Option(
|
114 |
+
"results_20230822_aorta_identified_added_by_Ray.csv",
|
115 |
+
help="Ground truth csv path",
|
116 |
+
),
|
117 |
+
no_extent: Optional[bool] = typer.Option(True, help="no extent file"),
|
118 |
+
write_viz: Optional[bool] = typer.Option(False, help="write viz images"),
|
119 |
+
gt_column_name: str = typer.Option("aorta_identified", help="Ground truth column"),
|
120 |
+
):
|
121 |
+
# check inputs are valid
|
122 |
+
assert input_dir.exists(), f"Input directory {input_dir} does not exist"
|
123 |
+
assert input_model.exists(), f"Input model {input_model} does not exist"
|
124 |
+
assert gt_csv_path.exists(), f"Ground truth csv {gt_csv_path} does not exist"
|
125 |
+
|
126 |
+
# load model
|
127 |
+
model = Model(
|
128 |
+
model_path=str(input_model),
|
129 |
+
imgsz=imgsz,
|
130 |
+
classes=[class_id], # filter by class id, only aorta
|
131 |
+
device="CPU",
|
132 |
+
plot_mask=True,
|
133 |
+
conf_thres=conf_thresh,
|
134 |
+
is_async=False,
|
135 |
+
n_jobs=1,
|
136 |
+
)
|
137 |
+
|
138 |
+
# setup output directory
|
139 |
+
out_dir = Path(out_dir)
|
140 |
+
# create a sub output directory of current date and time
|
141 |
+
start_t = datetime.now()
|
142 |
+
start_timestamp = start_t.strftime("%Y_%m_%d_%H_%M_%S")
|
143 |
+
out_dir = out_dir / f"max_aorta_result-{start_timestamp}"
|
144 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
145 |
+
|
146 |
+
# log to file
|
147 |
+
logger.add(str(out_dir.absolute()) + "/eval_{time}.log")
|
148 |
+
|
149 |
+
out_csv_p = out_dir / "results.csv"
|
150 |
+
out_trace_csv_p = out_dir / "trace.csv"
|
151 |
+
logger.info(f"Output directory: {out_dir}")
|
152 |
+
|
153 |
+
# find all directories in input_dir
|
154 |
+
input_dir = Path(input_dir)
|
155 |
+
sub_dirs = [x for x in input_dir.iterdir() if x.is_dir()]
|
156 |
+
sub_dirs.sort(key=lambda x: x.name) # sort sub_dirs by name ascending
|
157 |
+
logger.info(f"# of subdirectories found: {len(sub_dirs)}")
|
158 |
+
num_sub_dirs = len(sub_dirs)
|
159 |
+
has_patient_prefix = False if sub_dirs[0].name.startswith("video") else True
|
160 |
+
|
161 |
+
# setup csv headers
|
162 |
+
trace_headers = ["video", "image_idx", "aorta_pixels", "aorta_mm", "conf"]
|
163 |
+
headers = ["video", "max_aorta_pixels", "max_aorta_mm", "max_image_idx", "conf"]
|
164 |
+
if has_patient_prefix:
|
165 |
+
headers.insert(0, "patient_info")
|
166 |
+
|
167 |
+
# loop through each subdirectory of images
|
168 |
+
for idx, sub_dir in enumerate(sub_dirs):
|
169 |
+
max_aorta_w = -1 # max aorta width in pixels
|
170 |
+
max_aorta_w_mm = -1 # max aorta width in mm
|
171 |
+
max_aorta_viz = None
|
172 |
+
max_aorta_im_path = None
|
173 |
+
max_center_x, max_center_y = -1, -1
|
174 |
+
max_conf = None
|
175 |
+
max_im_n = ""
|
176 |
+
|
177 |
+
# read the extent file of the images
|
178 |
+
# the extent file should be in the same directory as the video
|
179 |
+
video_filename = (
|
180 |
+
sub_dir.name
|
181 |
+
if not has_patient_prefix
|
182 |
+
else "-".join(sub_dir.name.split("-")[1:])
|
183 |
+
)
|
184 |
+
extent_filename = video_filename.replace("video_", "extents_")
|
185 |
+
extent_file = sub_dir.parent / f"{extent_filename}.csv"
|
186 |
+
extents = None
|
187 |
+
if not no_extent:
|
188 |
+
assert extent_file.exists(), f"Extent file {extent_file} does not exist"
|
189 |
+
extents = pd.read_csv(extent_file).to_dict("records")
|
190 |
+
|
191 |
+
logger.info(f"Processing subdir {sub_dir.name} ({idx+1}/{num_sub_dirs})")
|
192 |
+
# find all images in sub_dir
|
193 |
+
images = list(sub_dir.glob("*.jpg"))
|
194 |
+
# Sort the list of images in ascending order by name
|
195 |
+
images.sort(key=lambda img: img.name)
|
196 |
+
logger.info(f"\t# of images found: {len(images)}")
|
197 |
+
|
198 |
+
# create a viz output directory for each sub_dir
|
199 |
+
out_sub_viz_dir = out_dir / sub_dir.name
|
200 |
+
Path(out_sub_viz_dir).mkdir(parents=True, exist_ok=True)
|
201 |
+
|
202 |
+
for im_idx, image_path in enumerate(tqdm(images)):
|
203 |
+
# read image
|
204 |
+
cv_frame = cv2.imread(str(image_path))
|
205 |
+
cv_width = cv_frame.shape[1]
|
206 |
+
|
207 |
+
# inference
|
208 |
+
viz_frame, results = model.predict(cv_frame)
|
209 |
+
bbox_xyxy = results[0]
|
210 |
+
conf = results[1]
|
211 |
+
masks = results[2]
|
212 |
+
|
213 |
+
# output viz image if the flag is set
|
214 |
+
if write_viz:
|
215 |
+
cv2.imwrite(
|
216 |
+
str(out_sub_viz_dir / f"{image_path.stem}_viz.jpg"),
|
217 |
+
viz_frame,
|
218 |
+
)
|
219 |
+
|
220 |
+
trace_row = [
|
221 |
+
f"{sub_dir.name}.{video_ext}",
|
222 |
+
image_path.stem,
|
223 |
+
-1,
|
224 |
+
-1,
|
225 |
+
conf,
|
226 |
+
]
|
227 |
+
|
228 |
+
if masks is not None or bbox_xyxy is not None:
|
229 |
+
# method 1: find the largest contour
|
230 |
+
# find min enclosing circle of mask
|
231 |
+
# mask = (masks * 255).astype(np.uint8)
|
232 |
+
# contours, _ = cv2.findContours(
|
233 |
+
# mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
|
234 |
+
# )
|
235 |
+
# largest_contour = max(contours, key=cv2.contourArea)
|
236 |
+
# (center_x, center_y), radius = cv2.minEnclosingCircle(largest_contour)
|
237 |
+
# aorta_width = radius * 2
|
238 |
+
|
239 |
+
# method 2: use the height of the bbox as a measure of aorta width
|
240 |
+
# because we observed that the width of the bbox is too large
|
241 |
+
aorta_width = bbox_xyxy[3] - bbox_xyxy[1]
|
242 |
+
|
243 |
+
# get physical unit
|
244 |
+
w_mm_left, w_mm_right, w_mm_per_pixel = None, None, None
|
245 |
+
if not no_extent:
|
246 |
+
w_mm_left = extents[im_idx]["Width-Left(mm)"]
|
247 |
+
w_mm_right = extents[im_idx]["Width-Right(mm)"]
|
248 |
+
assert w_mm_right > 0 and w_mm_left < 0
|
249 |
+
w_mm_per_pixel = (w_mm_right - w_mm_left) / cv_width
|
250 |
+
|
251 |
+
# update trace when aorta is found
|
252 |
+
trace_row[2] = aorta_width
|
253 |
+
trace_row[3] = aorta_width * w_mm_per_pixel if not no_extent else None
|
254 |
+
|
255 |
+
# output viz image when aorta is found
|
256 |
+
cv2.imwrite(
|
257 |
+
str(out_sub_viz_dir / f"{image_path.stem}_viz.jpg"),
|
258 |
+
viz_frame,
|
259 |
+
)
|
260 |
+
# copy the raw image to the output directory
|
261 |
+
shutil.copy(image_path, out_sub_viz_dir)
|
262 |
+
|
263 |
+
if aorta_width > max_aorta_w:
|
264 |
+
max_aorta_w = aorta_width
|
265 |
+
max_aorta_viz = viz_frame.copy()
|
266 |
+
max_aorta_im_path = image_path
|
267 |
+
|
268 |
+
# Note: only need to calculate the center if using method 1
|
269 |
+
# max_center_x = center_x
|
270 |
+
# max_center_y = center_y
|
271 |
+
|
272 |
+
max_im_n = image_path.stem
|
273 |
+
max_conf = conf
|
274 |
+
logger.info(
|
275 |
+
f"\tNew max aorta (pixels): {max_aorta_w:.2f}, conf: {max_conf:.2f}"
|
276 |
+
)
|
277 |
+
|
278 |
+
# convert pixels to mm
|
279 |
+
max_aorta_w_mm = (
|
280 |
+
max_aorta_w * w_mm_per_pixel if not no_extent else None
|
281 |
+
)
|
282 |
+
|
283 |
+
# save trace to csv
|
284 |
+
df = pd.DataFrame([trace_row], columns=trace_headers)
|
285 |
+
df.to_csv(
|
286 |
+
out_trace_csv_p,
|
287 |
+
mode="a",
|
288 |
+
header=not out_trace_csv_p.exists(),
|
289 |
+
index=False,
|
290 |
+
float_format="%.3f",
|
291 |
+
)
|
292 |
+
|
293 |
+
if max_aorta_w > 0:
|
294 |
+
logger.info(f"\tMax aorta (pixels): {max_aorta_w:.2f}")
|
295 |
+
# copy the raw image to the output directory
|
296 |
+
out_raw_p = out_dir / f"raw_{sub_dir.name}_{max_im_n}.jpg"
|
297 |
+
shutil.copy(max_aorta_im_path, out_raw_p)
|
298 |
+
|
299 |
+
# method 1 viz: draw enclosing circle on max_aorta_viz
|
300 |
+
# plot circle on max_aorta_viz
|
301 |
+
# cv2.circle(
|
302 |
+
# max_aorta_viz,
|
303 |
+
# (int(max_center_x), int(max_center_y)),
|
304 |
+
# int(max_aorta_w / 2),
|
305 |
+
# (0, 255, 0),
|
306 |
+
# 2,
|
307 |
+
# )
|
308 |
+
|
309 |
+
# region Save the image with extent
|
310 |
+
# convert the BGR image to RGB image
|
311 |
+
out_viz_p = out_dir / f"viz_{sub_dir.name}_{max_im_n}.jpg"
|
312 |
+
max_aorta_viz_rgb = cv2.cvtColor(max_aorta_viz, cv2.COLOR_BGR2RGB)
|
313 |
+
# Use matplotlib to save the image
|
314 |
+
# Get the size of the image in inches
|
315 |
+
dpi = plt.rcParams["figure.dpi"] # Get the default dpi value
|
316 |
+
figsize = (
|
317 |
+
max_aorta_viz_rgb.shape[1] / dpi,
|
318 |
+
max_aorta_viz_rgb.shape[0] / dpi,
|
319 |
+
) # width, height
|
320 |
+
# Create a new figure with the same aspect ratio as the image
|
321 |
+
fig = plt.figure(figsize=figsize)
|
322 |
+
if not no_extent:
|
323 |
+
# specify the extent of the image in the form [xmin, xmax, ymin, ymax]
|
324 |
+
extent = [
|
325 |
+
extents[im_idx]["Width-Left(mm)"],
|
326 |
+
extents[im_idx]["Width-Right(mm)"],
|
327 |
+
extents[im_idx]["Depth-Bottom(mm)"],
|
328 |
+
extents[im_idx]["Depth-Top(mm)"],
|
329 |
+
]
|
330 |
+
plt.imshow(max_aorta_viz_rgb, extent=extent)
|
331 |
+
plt.xlabel("Width [mm]")
|
332 |
+
plt.ylabel("Depth [mm]")
|
333 |
+
else:
|
334 |
+
plt.imshow(max_aorta_viz_rgb)
|
335 |
+
plt.savefig(str(out_viz_p))
|
336 |
+
plt.close(fig)
|
337 |
+
# cv2.imwrite(str(out_viz_p), max_aorta_viz)
|
338 |
+
# endregion
|
339 |
+
else:
|
340 |
+
logger.warning(f"\tNo aorta found in {sub_dir.name}")
|
341 |
+
patient_info = sub_dir.name.split("-")[0] if has_patient_prefix else ""
|
342 |
+
row = [
|
343 |
+
f"{sub_dir.name}.{video_ext}",
|
344 |
+
max_aorta_w,
|
345 |
+
max_aorta_w_mm,
|
346 |
+
max_im_n,
|
347 |
+
max_conf,
|
348 |
+
]
|
349 |
+
if has_patient_prefix:
|
350 |
+
row.insert(0, patient_info)
|
351 |
+
# remove patient info from sub_dir name
|
352 |
+
video_name = "-".join(sub_dir.name.split("-")[1:]) + f".{video_ext}"
|
353 |
+
row[1] = video_name
|
354 |
+
|
355 |
+
# export results to csv
|
356 |
+
# If file does not exist, this will create it, otherwise it will append to the file
|
357 |
+
df = pd.DataFrame([row], columns=headers)
|
358 |
+
df.to_csv(
|
359 |
+
out_csv_p,
|
360 |
+
mode="a",
|
361 |
+
header=not out_csv_p.exists(),
|
362 |
+
index=False,
|
363 |
+
float_format="%.3f",
|
364 |
+
)
|
365 |
+
|
366 |
+
# join the results with ground truth to add the ground truth column
|
367 |
+
# df_results = pd.read_csv(out_csv_p)
|
368 |
+
# df_gt = pd.read_csv(gt_csv_path)[["video", gt_column_name]] # id & gt columns
|
369 |
+
# df_gt_first = df_gt.drop_duplicates(subset="video", keep="first") # avoid new rows
|
370 |
+
# df_merged = pd.merge(df_results, df_gt_first, on="video", how="left")
|
371 |
+
# df_merged.to_csv(out_csv_p, header=True, index=False, float_format="%.3f")
|
372 |
+
|
373 |
+
# # show stats
|
374 |
+
# value_counts_with_nan = df_merged[gt_column_name].value_counts(dropna=False)
|
375 |
+
# total = len(df_merged)
|
376 |
+
# percentage = (value_counts_with_nan / total) * 100
|
377 |
+
# # Combine value counts and percentages into a DataFrame for better visualization
|
378 |
+
# stats = pd.DataFrame({"Count": value_counts_with_nan, "Percentage": percentage})
|
379 |
+
# logger.info(stats)
|
380 |
+
|
381 |
+
logger.info(f"Done! Results written to {out_csv_p}")
|
382 |
+
|
383 |
+
|
384 |
+
@app.command(help="Copy source images to viz result folder")
|
385 |
+
def copy_srcimg_to_vizdir(
|
386 |
+
src_img_dir: Path = typer.Argument(..., help="Source Images root directory"),
|
387 |
+
out_viz_dir: Path = typer.Argument(..., help="Target viz directory"),
|
388 |
+
):
|
389 |
+
vizs = list(Path(out_viz_dir).glob("**/*.jpg"))
|
390 |
+
for viz in vizs:
|
391 |
+
splits = viz.stem.split("_")
|
392 |
+
ori_img = Path(src_img_dir) / splits[1] / f"{splits[2]}.jpg"
|
393 |
+
shutil.copy(ori_img, Path(out_viz_dir) / f"{viz.stem}_src.jpg")
|
394 |
+
|
395 |
+
|
396 |
+
if __name__ == "__main__":
|
397 |
+
app()
|
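As a rough usage sketch of the Typer CLI defined above (function names become subcommands with dashes; all paths below are placeholders, and `eval` also expects the ground-truth CSV passed via `--gt-csv-path` to exist):

```shell
# 1) Export every ultrasound video into one folder of JPEG frames per video
python eval.py export-videos-to-images Patient_Info/ frames_out/ --ext avi

# 2) Measure the aorta on the exported frames and collect per-video max widths
python eval.py eval frames_out/ best_openvino_model/best.xml \
    --conf-thresh 0.5 --no-extent --write-viz \
    --gt-csv-path path/to/ground_truth.csv
```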
plots.py
ADDED
@@ -0,0 +1,303 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import cv2
|
4 |
+
from PIL import Image, ImageDraw  # ImageDraw is used by the PIL branch of Annotator
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
class Colors:
|
9 |
+
def __init__(self):
|
10 |
+
# hexs = matplotlib.colors.TABLEAU_COLORS.values()
|
11 |
+
hexs = (
|
12 |
+
"00FF00", # aorta class 0
|
13 |
+
"FF3838",
|
14 |
+
"FF701F",
|
15 |
+
"FFB21D",
|
16 |
+
"CFD231",
|
17 |
+
"48F90A",
|
18 |
+
"92CC17",
|
19 |
+
"3DDB86",
|
20 |
+
"1A9334",
|
21 |
+
"00D4BB",
|
22 |
+
"2C99A8",
|
23 |
+
"00C2FF",
|
24 |
+
"344593",
|
25 |
+
"6473FF",
|
26 |
+
"0018EC",
|
27 |
+
"8438FF",
|
28 |
+
"520085",
|
29 |
+
"CB38FF",
|
30 |
+
"FF95C8",
|
31 |
+
"FF37C7",
|
32 |
+
)
|
33 |
+
self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
|
34 |
+
self.n = len(self.palette)
|
35 |
+
|
36 |
+
def __call__(self, i, bgr=False):
|
37 |
+
c = self.palette[int(i) % self.n]
|
38 |
+
return (c[2], c[1], c[0]) if bgr else c
|
39 |
+
|
40 |
+
@staticmethod
|
41 |
+
def hex2rgb(h): # rgb order (PIL)
|
42 |
+
return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
|
43 |
+
|
44 |
+
|
45 |
+
colors = Colors() # create instance for 'from utils.plots import colors'
|
46 |
+
|
47 |
+
|
48 |
+
def is_ascii(s=""):
|
49 |
+
# Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
|
50 |
+
s = str(s) # convert list, tuple, None, etc. to str
|
51 |
+
return len(s.encode().decode("ascii", "ignore")) == len(s)
|
52 |
+
|
53 |
+
|
54 |
+
def clip_boxes(boxes, shape):
|
55 |
+
# Clip boxes (xyxy) to image shape (height, width)
|
56 |
+
if isinstance(boxes, torch.Tensor): # faster individually
|
57 |
+
boxes[:, 0].clamp_(0, shape[1]) # x1
|
58 |
+
boxes[:, 1].clamp_(0, shape[0]) # y1
|
59 |
+
boxes[:, 2].clamp_(0, shape[1]) # x2
|
60 |
+
boxes[:, 3].clamp_(0, shape[0]) # y2
|
61 |
+
else: # np.array (faster grouped)
|
62 |
+
boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
|
63 |
+
boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
|
64 |
+
|
65 |
+
|
66 |
+
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
|
67 |
+
# Rescale boxes (xyxy) from img1_shape to img0_shape
|
68 |
+
if ratio_pad is None: # calculate from img0_shape
|
69 |
+
gain = min(
|
70 |
+
img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]
|
71 |
+
) # gain = old / new
|
72 |
+
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (
|
73 |
+
img1_shape[0] - img0_shape[0] * gain
|
74 |
+
) / 2 # wh padding
|
75 |
+
else:
|
76 |
+
gain = ratio_pad[0][0]
|
77 |
+
pad = ratio_pad[1]
|
78 |
+
|
79 |
+
boxes[:, [0, 2]] -= pad[0] # x padding
|
80 |
+
boxes[:, [1, 3]] -= pad[1] # y padding
|
81 |
+
boxes[:, :4] /= gain
|
82 |
+
clip_boxes(boxes, img0_shape)
|
83 |
+
return boxes
|
84 |
+
|
85 |
+
|
86 |
+
def crop_mask(masks, boxes):
|
87 |
+
"""
|
88 |
+
"Crop" predicted masks by zeroing out everything not in the predicted bbox.
|
89 |
+
Vectorized by Chong (thanks Chong).
|
90 |
+
Args:
|
91 |
+
- masks should be a size [h, w, n] tensor of masks
|
92 |
+
- boxes should be a size [n, 4] tensor of bbox coords in relative point form
|
93 |
+
"""
|
94 |
+
|
95 |
+
n, h, w = masks.shape
|
96 |
+
x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(1,1,n)
|
97 |
+
r = torch.arange(w, device=masks.device, dtype=x1.dtype)[
|
98 |
+
None, None, :
|
99 |
+
] # rows shape(1,w,1)
|
100 |
+
c = torch.arange(h, device=masks.device, dtype=x1.dtype)[
|
101 |
+
None, :, None
|
102 |
+
] # cols shape(h,1,1)
|
103 |
+
|
104 |
+
return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
|
105 |
+
|
106 |
+
|
107 |
+
def process_mask(protos, masks_in, bboxes, shape, upsample=False):
|
108 |
+
"""
|
109 |
+
Crop before upsample.
|
110 |
+
proto_out: [mask_dim, mask_h, mask_w]
|
111 |
+
out_masks: [n, mask_dim], n is number of masks after nms
|
112 |
+
bboxes: [n, 4], n is number of masks after nms
|
113 |
+
shape:input_image_size, (h, w)
|
114 |
+
return: h, w, n
|
115 |
+
"""
|
116 |
+
|
117 |
+
c, mh, mw = protos.shape # CHW
|
118 |
+
ih, iw = shape
|
119 |
+
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) # CHW
|
120 |
+
|
121 |
+
downsampled_bboxes = bboxes.clone()
|
122 |
+
downsampled_bboxes[:, 0] *= mw / iw
|
123 |
+
downsampled_bboxes[:, 2] *= mw / iw
|
124 |
+
downsampled_bboxes[:, 3] *= mh / ih
|
125 |
+
downsampled_bboxes[:, 1] *= mh / ih
|
126 |
+
|
127 |
+
masks = crop_mask(masks, downsampled_bboxes) # CHW
|
128 |
+
if upsample:
|
129 |
+
masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[
|
130 |
+
0
|
131 |
+
] # CHW
|
132 |
+
return masks.gt_(0.5)
|
133 |
+
|
134 |
+
|
135 |
+
def scale_image(im1_shape, masks, im0_shape, ratio_pad=None):
|
136 |
+
"""
|
137 |
+
img1_shape: model input shape, [h, w]
|
138 |
+
img0_shape: origin pic shape, [h, w, 3]
|
139 |
+
masks: [h, w, num]
|
140 |
+
"""
|
141 |
+
# Rescale coordinates (xyxy) from im1_shape to im0_shape
|
142 |
+
if ratio_pad is None: # calculate from im0_shape
|
143 |
+
gain = min(
|
144 |
+
im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]
|
145 |
+
) # gain = old / new
|
146 |
+
pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (
|
147 |
+
im1_shape[0] - im0_shape[0] * gain
|
148 |
+
) / 2 # wh padding
|
149 |
+
else:
|
150 |
+
pad = ratio_pad[1]
|
151 |
+
top, left = int(pad[1]), int(pad[0]) # y, x
|
152 |
+
bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])
|
153 |
+
|
154 |
+
if len(masks.shape) < 2:
|
155 |
+
raise ValueError(
|
156 |
+
f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}'
|
157 |
+
)
|
158 |
+
masks = masks[top:bottom, left:right]
|
159 |
+
# masks = masks.permute(2, 0, 1).contiguous()
|
160 |
+
# masks = F.interpolate(masks[None], im0_shape[:2], mode='bilinear', align_corners=False)[0]
|
161 |
+
# masks = masks.permute(1, 2, 0).contiguous()
|
162 |
+
masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))
|
163 |
+
|
164 |
+
if len(masks.shape) == 2:
|
165 |
+
masks = masks[:, :, None]
|
166 |
+
return masks
|
167 |
+
|
168 |
+
|
169 |
+
class Annotator:
|
170 |
+
# YOLOv5 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
|
171 |
+
def __init__(
|
172 |
+
self,
|
173 |
+
im,
|
174 |
+
line_width=None,
|
175 |
+
font_size=None,
|
176 |
+
font="Arial.ttf",
|
177 |
+
pil=False,
|
178 |
+
example="abc",
|
179 |
+
):
|
180 |
+
assert (
|
181 |
+
im.data.contiguous
|
182 |
+
), "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images."
|
183 |
+
non_ascii = not is_ascii(
|
184 |
+
example
|
185 |
+
) # non-latin labels, i.e. asian, arabic, cyrillic
|
186 |
+
self.pil = pil or non_ascii
|
187 |
+
if self.pil: # use PIL
|
188 |
+
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
189 |
+
self.draw = ImageDraw.Draw(self.im)
|
190 |
+
self.font = check_pil_font(
|
191 |
+
font="Arial.Unicode.ttf" if non_ascii else font,
|
192 |
+
size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12),
|
193 |
+
)
|
194 |
+
else: # use cv2
|
195 |
+
self.im = im
|
196 |
+
self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
|
197 |
+
|
198 |
+
def box_label(
|
199 |
+
self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255)
|
200 |
+
):
|
201 |
+
# Add one xyxy box to image with label
|
202 |
+
if self.pil or not is_ascii(label):
|
203 |
+
self.draw.rectangle(box, width=self.lw, outline=color) # box
|
204 |
+
if label:
|
205 |
+
w, h = self.font.getsize(label) # text width, height
|
206 |
+
outside = box[1] - h >= 0 # label fits outside box
|
207 |
+
self.draw.rectangle(
|
208 |
+
(
|
209 |
+
box[0],
|
210 |
+
box[1] - h if outside else box[1],
|
211 |
+
box[0] + w + 1,
|
212 |
+
box[1] + 1 if outside else box[1] + h + 1,
|
213 |
+
),
|
214 |
+
fill=color,
|
215 |
+
)
|
216 |
+
# self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
|
217 |
+
self.draw.text(
|
218 |
+
(box[0], box[1] - h if outside else box[1]),
|
219 |
+
label,
|
220 |
+
fill=txt_color,
|
221 |
+
font=self.font,
|
222 |
+
)
|
223 |
+
else: # cv2
|
224 |
+
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
|
225 |
+
cv2.rectangle(
|
226 |
+
self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA
|
227 |
+
)
|
228 |
+
if label:
|
229 |
+
tf = max(self.lw - 1, 1) # font thickness
|
230 |
+
w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[
|
231 |
+
0
|
232 |
+
] # text width, height
|
233 |
+
outside = p1[1] - h >= 3
|
234 |
+
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
|
235 |
+
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
|
236 |
+
cv2.putText(
|
237 |
+
self.im,
|
238 |
+
label,
|
239 |
+
(p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
|
240 |
+
0,
|
241 |
+
self.lw / 3,
|
242 |
+
txt_color,
|
243 |
+
thickness=tf,
|
244 |
+
lineType=cv2.LINE_AA,
|
245 |
+
)
|
246 |
+
|
247 |
+
def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
|
248 |
+
"""Plot masks at once.
|
249 |
+
Args:
|
250 |
+
masks (tensor): predicted masks on cuda, shape: [n, h, w]
|
251 |
+
colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n]
|
252 |
+
im_gpu (tensor): img is in cuda, shape: [3, h, w], range: [0, 1]
|
253 |
+
alpha (float): mask transparency: 0.0 fully transparent, 1.0 opaque
|
254 |
+
"""
|
255 |
+
im_gpu = torch.from_numpy(im_gpu) # not sure why we need this fix?
|
256 |
+
# print(im_gpu)
|
257 |
+
if self.pil:
|
258 |
+
# convert to numpy first
|
259 |
+
self.im = np.asarray(self.im).copy()
|
260 |
+
if len(masks) == 0:
|
261 |
+
self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
|
262 |
+
colors = torch.tensor(colors, device=im_gpu.device, dtype=torch.float32) / 255.0
|
263 |
+
colors = colors[:, None, None] # shape(n,1,1,3)
|
264 |
+
masks = masks.unsqueeze(3) # shape(n,h,w,1)
|
265 |
+
masks_color = masks * (colors * alpha) # shape(n,h,w,3)
|
266 |
+
|
267 |
+
inv_alph_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
|
268 |
+
mcs = (masks_color * inv_alph_masks).sum(
|
269 |
+
0
|
270 |
+
) * 2 # mask color summand shape(n,h,w,3)
|
271 |
+
|
272 |
+
im_gpu = im_gpu.flip(dims=[0]) # flip channel
|
273 |
+
im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
|
274 |
+
im_gpu = im_gpu * inv_alph_masks[-1] + mcs
|
275 |
+
im_mask = (im_gpu * 255).byte().cpu().numpy()
|
276 |
+
self.im[:] = (
|
277 |
+
im_mask
|
278 |
+
if retina_masks
|
279 |
+
else scale_image(im_gpu.shape, im_mask, self.im.shape)
|
280 |
+
)
|
281 |
+
if self.pil:
|
282 |
+
# convert im back to PIL and update draw
|
283 |
+
self.fromarray(self.im)
|
284 |
+
|
285 |
+
def rectangle(self, xy, fill=None, outline=None, width=1):
|
286 |
+
# Add rectangle to image (PIL-only)
|
287 |
+
self.draw.rectangle(xy, fill, outline, width)
|
288 |
+
|
289 |
+
def text(self, xy, text, txt_color=(255, 255, 255), anchor="top"):
|
290 |
+
# Add text to image (PIL-only)
|
291 |
+
if anchor == "bottom": # start y from font bottom
|
292 |
+
w, h = self.font.getsize(text) # text width, height
|
293 |
+
xy[1] += 1 - h
|
294 |
+
self.draw.text(xy, text, fill=txt_color, font=self.font)
|
295 |
+
|
296 |
+
def fromarray(self, im):
|
297 |
+
# Update self.im from a numpy array
|
298 |
+
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
299 |
+
self.draw = ImageDraw.Draw(self.im)
|
300 |
+
|
301 |
+
def result(self):
|
302 |
+
# Return annotated image as array
|
303 |
+
return np.asarray(self.im)
|
predict.py
ADDED
@@ -0,0 +1,470 @@
1 |
+
# Must import torch before onnxruntime, else could not create cuda context
|
2 |
+
# ref: https://github.com/microsoft/onnxruntime/issues/11092#issuecomment-1386840174
|
3 |
+
import torch, torchvision
|
4 |
+
import onnxruntime
|
5 |
+
|
6 |
+
from time import perf_counter
|
7 |
+
from openvino.runtime import Core, Layout, get_batch, AsyncInferQueue
|
8 |
+
from pathlib import Path
|
9 |
+
import yaml
|
10 |
+
import cv2
|
11 |
+
import numpy as np
|
12 |
+
import time
|
13 |
+
from plots import Annotator, process_mask, scale_boxes, scale_image, colors
|
14 |
+
from loguru import logger
|
15 |
+
|
16 |
+
|
17 |
+
def from_numpy(x):
|
18 |
+
return torch.from_numpy(x) if isinstance(x, np.ndarray) else x
|
19 |
+
|
20 |
+
|
21 |
+
def yaml_load(file="data.yaml"):
|
22 |
+
# Single-line safe yaml loading
|
23 |
+
with open(file, errors="ignore") as f:
|
24 |
+
return yaml.safe_load(f)
|
25 |
+
|
26 |
+
|
27 |
+
def load_metadata(f=Path("path/to/meta.yaml")):
|
28 |
+
# Load metadata from meta.yaml if it exists
|
29 |
+
if f.exists():
|
30 |
+
d = yaml_load(f)
|
31 |
+
return d["stride"], d["names"] # assign stride, names
|
32 |
+
return None, None
|
33 |
+
|
34 |
+
|
35 |
+
def letterbox(
|
36 |
+
im,
|
37 |
+
new_shape=(640, 640),
|
38 |
+
color=(114, 114, 114),
|
39 |
+
auto=True,
|
40 |
+
scale_fill=False,
|
41 |
+
scaleup=True,
|
42 |
+
stride=32,
|
43 |
+
):
|
44 |
+
# Resize and pad image while meeting stride-multiple constraints
|
45 |
+
shape = im.shape[:2] # current shape [height, width]
|
46 |
+
if isinstance(new_shape, int):
|
47 |
+
new_shape = (new_shape, new_shape)
|
48 |
+
|
49 |
+
# Scale ratio (new / old)
|
50 |
+
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
51 |
+
if not scaleup: # only scale down, do not scale up (for better val mAP)
|
52 |
+
r = min(r, 1.0)
|
53 |
+
|
54 |
+
# Compute padding
|
55 |
+
ratio = r, r # width, height ratios
|
56 |
+
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
|
57 |
+
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
|
58 |
+
if auto: # minimum rectangle
|
59 |
+
dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
|
60 |
+
elif scale_fill: # stretch
|
61 |
+
dw, dh = 0.0, 0.0
|
62 |
+
new_unpad = (new_shape[1], new_shape[0])
|
63 |
+
ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
|
64 |
+
|
65 |
+
dw /= 2 # divide padding into 2 sides
|
66 |
+
dh /= 2
|
67 |
+
|
68 |
+
if shape[::-1] != new_unpad: # resize
|
69 |
+
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
|
70 |
+
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
71 |
+
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
72 |
+
im = cv2.copyMakeBorder(
|
73 |
+
im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
|
74 |
+
) # add border
|
75 |
+
return im, ratio, (dw, dh)
|
76 |
+
|
77 |
+
|
78 |
+
def xywh2xyxy(x):
|
79 |
+
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
80 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
81 |
+
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
|
82 |
+
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
|
83 |
+
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
|
84 |
+
y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
|
85 |
+
return y
|
86 |
+
|
87 |
+
|
88 |
+
def box_iou(box1, box2, eps=1e-7):
|
89 |
+
# https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
|
90 |
+
"""
|
91 |
+
Return intersection-over-union (Jaccard index) of boxes.
|
92 |
+
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
|
93 |
+
Arguments:
|
94 |
+
box1 (Tensor[N, 4])
|
95 |
+
box2 (Tensor[M, 4])
|
96 |
+
Returns:
|
97 |
+
iou (Tensor[N, M]): the NxM matrix containing the pairwise
|
98 |
+
IoU values for every element in boxes1 and boxes2
|
99 |
+
"""
|
100 |
+
|
101 |
+
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
|
102 |
+
(a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
|
103 |
+
inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
|
104 |
+
|
105 |
+
# IoU = inter / (area1 + area2 - inter)
|
106 |
+
return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
|
107 |
+
|
108 |
+
|
109 |
+
def non_max_suppression(
|
110 |
+
prediction,
|
111 |
+
conf_thres=0.25,
|
112 |
+
iou_thres=0.45,
|
113 |
+
classes=None,
|
114 |
+
agnostic=False,
|
115 |
+
multi_label=False,
|
116 |
+
labels=(),
|
117 |
+
max_det=300,
|
118 |
+
nm=0, # number of masks
|
119 |
+
redundant=True, # require redundant detections
|
120 |
+
):
|
121 |
+
"""Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
|
122 |
+
Returns:
|
123 |
+
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
|
124 |
+
"""
|
125 |
+
|
126 |
+
if isinstance(
|
127 |
+
prediction, (list, tuple)
|
128 |
+
): # YOLOv5 model in validation model, output = (inference_out, loss_out)
|
129 |
+
prediction = prediction[0] # select only inference output
|
130 |
+
|
131 |
+
device = prediction.device
|
132 |
+
mps = "mps" in device.type # Apple MPS
|
133 |
+
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
|
134 |
+
prediction = prediction.cpu()
|
135 |
+
bs = prediction.shape[0] # batch size
|
136 |
+
nc = prediction.shape[2] - nm - 5 # number of classes
|
137 |
+
xc = prediction[..., 4] > conf_thres # candidates
|
138 |
+
|
139 |
+
# Checks
|
140 |
+
assert (
|
141 |
+
0 <= conf_thres <= 1
|
142 |
+
), f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
|
143 |
+
assert (
|
144 |
+
0 <= iou_thres <= 1
|
145 |
+
), f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
|
146 |
+
|
147 |
+
# Settings
|
148 |
+
# min_wh = 2 # (pixels) minimum box width and height
|
149 |
+
max_wh = 7680 # (pixels) maximum box width and height
|
150 |
+
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
|
151 |
+
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
152 |
+
merge = False # use merge-NMS
|
153 |
+
|
154 |
+
t = time.time()
|
155 |
+
mi = 5 + nc # mask start index
|
156 |
+
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
|
157 |
+
for xi, x in enumerate(prediction): # image index, image inference
|
158 |
+
# Apply constraints
|
159 |
+
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
|
160 |
+
x = x[xc[xi]] # confidence
|
161 |
+
|
162 |
+
# Cat apriori labels if autolabelling
|
163 |
+
if labels and len(labels[xi]):
|
164 |
+
lb = labels[xi]
|
165 |
+
v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
|
166 |
+
v[:, :4] = lb[:, 1:5] # box
|
167 |
+
v[:, 4] = 1.0 # conf
|
168 |
+
v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls
|
169 |
+
x = torch.cat((x, v), 0)
|
170 |
+
|
171 |
+
# If none remain process next image
|
172 |
+
if not x.shape[0]:
|
173 |
+
continue
|
174 |
+
|
175 |
+
# Compute conf
|
176 |
+
x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
|
177 |
+
|
178 |
+
# Box/Mask
|
179 |
+
box = xywh2xyxy(
|
180 |
+
x[:, :4]
|
181 |
+
) # center_x, center_y, width, height) to (x1, y1, x2, y2)
|
182 |
+
mask = x[:, mi:] # zero columns if no masks
|
183 |
+
|
184 |
+
# Detections matrix nx6 (xyxy, conf, cls)
|
185 |
+
if multi_label:
|
186 |
+
i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
|
187 |
+
x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
|
188 |
+
else: # best class only
|
189 |
+
conf, j = x[:, 5:mi].max(1, keepdim=True)
|
190 |
+
x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
|
191 |
+
|
192 |
+
# Filter by class
|
193 |
+
if classes is not None:
|
194 |
+
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
|
195 |
+
|
196 |
+
# Apply finite constraint
|
197 |
+
# if not torch.isfinite(x).all():
|
198 |
+
# x = x[torch.isfinite(x).all(1)]
|
199 |
+
|
200 |
+
# Check shape
|
201 |
+
n = x.shape[0] # number of boxes
|
202 |
+
if not n: # no boxes
|
203 |
+
continue
|
204 |
+
elif n > max_nms: # excess boxes
|
205 |
+
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
|
206 |
+
else:
|
207 |
+
x = x[x[:, 4].argsort(descending=True)] # sort by confidence
|
208 |
+
|
209 |
+
# Batched NMS
|
210 |
+
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
211 |
+
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
212 |
+
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
|
213 |
+
if i.shape[0] > max_det: # limit detections
|
214 |
+
i = i[:max_det]
|
215 |
+
if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
|
216 |
+
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
|
217 |
+
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
|
218 |
+
weights = iou * scores[None] # box weights
|
219 |
+
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(
|
220 |
+
1, keepdim=True
|
221 |
+
) # merged boxes
|
222 |
+
if redundant:
|
223 |
+
i = i[iou.sum(1) > 1] # require redundancy
|
224 |
+
|
225 |
+
output[xi] = x[i]
|
226 |
+
if mps:
|
227 |
+
output[xi] = output[xi].to(device)
|
228 |
+
|
229 |
+
return output
|
230 |
+
|
231 |
+
|
232 |
+
class Model:
    def __init__(
        self,
        model_path,
        imgsz=320,
        classes=None,
        device="CPU",
        plot_mask=False,
        conf_thres=0.7,
        n_jobs=1,
        is_async=False,
    ):
        # filter by class: classes=[0], or classes=[0, 2, 3]
        model_type = "onnx" if Path(model_path).suffix == ".onnx" else "openvino"
        assert Path(model_path).exists(), f"Model {model_path} not found"
        assert Path(model_path).suffix in (
            ".onnx",
            ".xml",
        ), "Model must be .onnx or .xml"
        self.model_type = model_type
        self.model_path = model_path
        self.imgsz = imgsz
        self.classes = classes
        self.plot_mask = plot_mask
        self.conf_thres = conf_thres

        # async settings
        self.n_jobs = n_jobs
        self.is_async = is_async
        self.completed_results = {}  # key: frame_id, value: inference results
        self.ori_cv_imgs = {}  # key: frame_id, value: original cv image
        self.prep_cv_imgs = {}  # key: frame_id, value: preprocessed cv image

        if self.model_type == "onnx":
            assert is_async is False, "Async mode is not supported for ONNX models"
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
            session = onnxruntime.InferenceSession(model_path, providers=providers)
            self.session = session
            output_names = [x.name for x in session.get_outputs()]
            self.output_names = output_names
            meta = session.get_modelmeta().custom_metadata_map  # metadata
            if "stride" in meta:
                stride, names = int(meta["stride"]), eval(meta["names"])
                self.stride = stride
                self.names = names
        elif self.model_type == "openvino":
            # load OpenVINO model
            assert Path(model_path).suffix == ".xml", "OpenVINO model must be .xml"
            ie = Core()
            weights = Path(model_path).with_suffix(".bin").as_posix()
            network = ie.read_model(model=model_path, weights=weights)
            if network.get_parameters()[0].get_layout().empty:
                network.get_parameters()[0].set_layout(Layout("NCHW"))
            batch_dim = get_batch(network)
            if batch_dim.is_static:
                batch_size = batch_dim.get_length()

            # To run inference on M1, we must export the IR model using "mo --use_legacy_frontend"
            # Otherwise, we would get the following error when compiling the model
            # https://github.com/openvinotoolkit/openvino/issues/12476#issuecomment-1222202804
            config = {}
            if n_jobs == "auto":
                config = {"PERFORMANCE_HINT": "THROUGHPUT"}
            self.executable_network = ie.compile_model(
                network, device_name=device, config=config
            )
            num_requests = self.executable_network.get_property(
                "OPTIMAL_NUMBER_OF_INFER_REQUESTS"
            )
            self.n_jobs = num_requests if n_jobs == "auto" else int(n_jobs)
            logger.info(f"Optimal number of infer requests should be: {num_requests}")
            self.stride, self.names = load_metadata(
                Path(weights).with_suffix(".yaml")
            )  # load metadata

        if is_async:
            logger.info(f"Using num of infer requests jobs: {n_jobs}")
            self.pipeline = AsyncInferQueue(self.executable_network, self.n_jobs)
            self.pipeline.set_callback(self.callback)

    def preprocess(self, cv_img, pt=False):
        im = letterbox(cv_img, self.imgsz, stride=self.stride, auto=pt)[0]  # padded resize
        im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        im = np.ascontiguousarray(im)  # contiguous
        im = torch.from_numpy(im)
        im = im.float()  # uint8 to fp16/32
        im /= 255  # 0 - 255 to 0.0 - 1.0
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        im = im.cpu().numpy()  # torch to numpy
        return im

    def postprocess(self, y, ori_cv_im, prep_im):
        y = [from_numpy(x) for x in y]
        pred, proto = y[0], y[-1]

        im0 = ori_cv_im

        # NMS
        iou_thres = 0.45
        agnostic_nms = False
        max_det = 1  # maximum detections per image, only 1 aorta is needed
        pred = non_max_suppression(
            pred,
            self.conf_thres,
            iou_thres,
            self.classes,
            agnostic_nms,
            max_det=max_det,
            nm=32,
        )

        # Process predictions
        line_thickness = 3
        annotator = Annotator(
            np.ascontiguousarray(im0),
            line_width=line_thickness,
            example=str(self.names),
        )
        i = 0
        det = pred[0]
        im = prep_im
        r_xyxy, r_conf, r_masks = None, None, None
        if len(pred[0]):
            masks = process_mask(
                proto[i],
                det[:, 6:],
                det[:, :4],
                (self.imgsz, self.imgsz),
                upsample=True,
            )  # HWC
            det[:, :4] = scale_boxes(
                (self.imgsz, self.imgsz), det[:, :4], im0.shape
            ).round()  # rescale boxes to im0 size

            # Mask plotting
            if self.plot_mask:
                annotator.masks(
                    masks,
                    colors=[colors(x, True) for x in det[:, 5]],
                    im_gpu=im[i],
                    alpha=0.1,
                )

            # Write results
            for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])):
                # Add bbox to image
                c = int(cls)  # integer class
                label = f"{self.names[c]} {conf:.2f}"
                annotator.box_label(xyxy, label, color=colors(c, True))
                r_xyxy = xyxy
                r_conf = conf
            r_xyxy = [i.int().numpy().item() for i in r_xyxy]
            r_conf = r_conf.numpy().item()
            r_masks = scale_image((self.imgsz, self.imgsz), masks.numpy()[0], im0.shape)
        return annotator.result(), (r_xyxy, r_conf, r_masks)

    def predict(self, cv_img):
        # return the annotated image and the bounding box
        result_cv_img, xyxy = None, None
        im = self.preprocess(cv_img)
        if self.model_type == "onnx":
            y = self.session.run(
                self.output_names, {self.session.get_inputs()[0].name: im}
            )
        elif self.model_type == "openvino":
            # OpenVINO model inference
            # Note: Please use FP32 model on M1, otherwise you will get many runtime errors
            # Very slow on M1, but works
            # start = perf_counter()
            y = list(self.executable_network([im]).values())
            # logger.info(f"OpenVINO inference time: {perf_counter() - start:.3f}s")
        result_cv_img, others = self.postprocess(y, cv_img, im)
        return result_cv_img, others

    def callback(self, request, userdata):
        # callback function for AsyncInferQueue
        outputs = request.outputs
        frame_id = userdata
        self.completed_results[frame_id] = [i.data for i in outputs]

    def predict_async(self, cv_img, frame_id):
        assert self.is_async, "Please set is_async=True when initializing the model"
        self.ori_cv_imgs[frame_id] = cv_img
        im = self.preprocess(cv_img)
        self.prep_cv_imgs[frame_id] = im

        # Note: The start_async function call is not required to be synchronized - it waits for any available job if the queue is busy/overloaded.
        # https://docs.openvino.ai/latest/openvino_docs_OV_UG_Python_API_exclusives.html#asyncinferqueue
        #
        # idle_id = self.pipeline.get_idle_request_id()
        # self.pipeline.start_async({idle_id: im}, frame_id)
        self.pipeline.start_async({0: im}, frame_id)

    def is_free_to_infer_async(self):
        """Returns True if any free request in the pool, otherwise False"""
        assert self.is_async, "Please set is_async=True when initializing the model"
        return self.pipeline.is_ready()

    def get_result(self, frame_id):
        """Returns the inference result for the given frame_id"""
        assert self.is_async, "Please set is_async=True when initializing the model"
        if frame_id in self.completed_results:
            y = self.completed_results.pop(frame_id)
            cv_img = self.ori_cv_imgs.pop(frame_id)
            im = self.prep_cv_imgs.pop(frame_id)
            result_cv_img, others = self.postprocess(y, cv_img, im)
            return result_cv_img, others
        return None

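Before the demo entry point below, a hedged sketch (not part of the committed file) of how the async API of `Model` is intended to be driven; the IR path and webcam source are placeholder assumptions.

```python
# Hedged usage sketch: drive Model in async mode on a video stream.
# The weights path and VideoCapture(0) source are placeholder assumptions.
import cv2

model = Model("weights/best_openvino_model/best.xml", imgsz=320, is_async=True, n_jobs="auto")
cap = cv2.VideoCapture(0)
next_id, next_to_show = 0, 0
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    if model.is_free_to_infer_async():  # submit only when an infer request slot is free
        model.predict_async(frame, next_id)
        next_id += 1
    result = model.get_result(next_to_show)  # None until that frame_id has completed
    if result is not None:
        annotated, (xyxy, conf, masks) = result
        cv2.imshow("aorta", annotated)
        next_to_show += 1
    if cv2.waitKey(1) & 0xFF == ord("q"):
        break
cap.release()
cv2.destroyAllWindows()
```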
if __name__ == "__main__":
    # model weights; keep one path active (earlier candidates left as comments)
    # m_p = "weights/yolov7seg-JH-v1.onnx"
    # m_p = "weights/yolov5s-seg-MK-v1.onnx"
    m_p = "weights/best_openvino_model/best.xml"
    imgsz = 320
    # imgsz = 640
    model = Model(model_path=m_p, imgsz=imgsz)

    # run inference on an image using the loaded model
    # source = 'Tim_3-0-00-20.05.jpg'
    path = "data/Jimmy_2-0-00-04.63.jpg"
    assert Path(path).exists(), f"Input image {path} doesn't exist"

    # output path
    save_dir = "runs/predict"
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    out_p = f"{save_dir}/{Path(path).stem}.jpg"

    # load the image and run prediction
    im0 = cv2.imread(path)  # BGR
    result_cv_img, _ = model.predict(im0)
    if result_cv_img is not None:
        cv2.imwrite(out_p, result_cv_img)
        logger.info(f"Saved result to {out_p}")
    else:
        logger.error("No result, something went wrong")
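The demo above loads `weights/best_openvino_model/best.xml`, i.e. an OpenVINO IR export of a trained yolov5s-seg checkpoint. A hedged sketch of how such weights are typically produced, assuming the upstream yolov5 `export.py` interface at the tested commit (the checkpoint path is a placeholder):

```shell
# Assumption: run inside the ultralytics/yolov5 checkout; flags follow upstream export.py
python export.py --weights path/to/best.pt --include onnx openvino --imgsz 320
```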
requirements.txt
ADDED
@@ -0,0 +1,26 @@

# dataset processing
loguru==0.6.0
typer[all]==0.7.0
fiftyone==0.19.1
pycocotools==2.0.6

torch==1.13.0
torchvision==0.14.0
openvino==2022.2.0; sys_platform != "darwin"
openvino-arm==2022.1.0.1; sys_platform == "darwin"
opencv-python==4.6.0.66
PyYAML==6.0
onnx==1.13.1
onnxruntime==1.13.1
onnxruntime-gpu==1.13.1; sys_platform != "darwin"

# demo GUI
PySide6==6.4.1
scipy==1.9.3
matplotlib==3.5.2

# demo plot
plotly==5.11.0
pandas==1.5.2
kaleido==0.2.1; platform_system != "Windows"
kaleido==0.1.0post1; platform_system == "Windows"
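The pins above select different inference wheels per platform (Apple Silicon gets `openvino-arm`, other platforms get `onnxruntime-gpu`). A quick, hedged sanity check (not part of the repo) that both inference backends load after installation:

```python
# Hedged sanity check: list the devices/providers the installed backends can see.
from openvino.runtime import Core
import onnxruntime

print(Core().available_devices)               # e.g. ['CPU']
print(onnxruntime.get_available_providers())  # e.g. ['CPUExecutionProvider']
```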
roi.py
ADDED
@@ -0,0 +1,34 @@

import cv2
import json
import argparse
from pathlib import Path


# get the image from arguments
ap = argparse.ArgumentParser()
ap.add_argument("-i", "--image", required=True, help="Path to the image")
args = vars(ap.parse_args())

# check if the image does exist
img_p = Path(args["image"])
assert img_p.exists(), "Image does not exist"

# Read the image
img = cv2.imread(args["image"])

# Select the ROI from the image
ROI = cv2.selectROI("Image", img, False, False)

# Append the ROI coordinates to the csv file
# header: filename, ori_width, ori_height, roi_x, roi_y, roi_width, roi_height
with open("roi.csv", "a") as f:
    # if no file exists, create a new one with the header
    if f.tell() == 0:
        f.write("filename,ori_width,ori_height,roi_x,roi_y,roi_width,roi_height\n")
    ori_w, ori_h = img.shape[1], img.shape[0]
    f.write(f"{img_p.name},{ori_w},{ori_h},{ROI[0]},{ROI[1]},{ROI[2]},{ROI[3]}\n")

# Display cropped image
cropped = img[ROI[1] : ROI[1] + ROI[3], ROI[0] : ROI[0] + ROI[2]]
cv2.imshow("Cropped Image", cropped)
cv2.waitKey(0)
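A hedged usage example for `roi.py`: draw the ROI in the OpenCV window, confirm with Enter/Space, and a row is appended to `roi.csv`. The image path reuses the demo frame referenced in `predict.py`; the numbers in the sample row are purely illustrative.

```shell
python roi.py -i data/Jimmy_2-0-00-04.63.jpg
# roi.csv then gains a row like (values depend on the selection):
# filename,ori_width,ori_height,roi_x,roi_y,roi_width,roi_height
# Jimmy_2-0-00-04.63.jpg,1024,768,256,128,400,300
```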
try_chart.ipynb
ADDED
The diff for this file is too large to render. See raw diff.

usgfw2wrapper.dll
ADDED
Binary file (15.4 kB).

weights/.keep
ADDED
@@ -0,0 +1 @@

weights/yolov5s-v2
ADDED
@@ -0,0 +1 @@