Upload model

- README.md +199 -0
- config.json +17 -0
- configuration.py +21 -0
- model.safetensors +3 -0
- modeling.py +389 -0
README.md
ADDED
@@ -0,0 +1,199 @@
---
library_name: transformers
tags: []
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->

## Model Details

### Model Description

<!-- Provide a longer summary of what this model is. -->

This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.

- **Developed by:** [More Information Needed]
- **Funded by [optional]:** [More Information Needed]
- **Shared by [optional]:** [More Information Needed]
- **Model type:** [More Information Needed]
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Finetuned from model [optional]:** [More Information Needed]

### Model Sources [optional]

<!-- Provide the basic links for the model. -->

- **Repository:** [More Information Needed]
- **Paper [optional]:** [More Information Needed]
- **Demo [optional]:** [More Information Needed]

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

### Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

[More Information Needed]

### Downstream Use [optional]

<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->

[More Information Needed]

### Out-of-Scope Use

<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->

[More Information Needed]

## Bias, Risks, and Limitations

<!-- This section is meant to convey both technical and sociotechnical limitations. -->

[More Information Needed]

### Recommendations

<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.

## How to Get Started with the Model

Use the code below to get started with the model.

[More Information Needed]

## Training Details

### Training Data

<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

[More Information Needed]

### Training Procedure

<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->

#### Preprocessing [optional]

[More Information Needed]

#### Training Hyperparameters

- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->

#### Speeds, Sizes, Times [optional]

<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->

[More Information Needed]

## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->

### Testing Data, Factors & Metrics

#### Testing Data

<!-- This should link to a Dataset Card if possible. -->

[More Information Needed]

#### Factors

<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->

[More Information Needed]

#### Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->

[More Information Needed]

### Results

[More Information Needed]

#### Summary

## Model Examination [optional]

<!-- Relevant interpretability work for the model goes here -->

[More Information Needed]

## Environmental Impact

<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** [More Information Needed]
- **Hours used:** [More Information Needed]
- **Cloud Provider:** [More Information Needed]
- **Compute Region:** [More Information Needed]
- **Carbon Emitted:** [More Information Needed]

## Technical Specifications [optional]

### Model Architecture and Objective

[More Information Needed]

### Compute Infrastructure

[More Information Needed]

#### Hardware

[More Information Needed]

#### Software

[More Information Needed]

## Citation [optional]

<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->

**BibTeX:**

[More Information Needed]

**APA:**

[More Information Needed]

## Glossary [optional]

<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->

[More Information Needed]

## More Information [optional]

[More Information Needed]

## Model Card Authors [optional]

[More Information Needed]

## Model Card Contact

[More Information Needed]
config.json
ADDED
@@ -0,0 +1,17 @@
{
  "architectures": [
    "CTCropModel"
  ],
  "auto_map": {
    "AutoConfig": "configuration.CTCropConfig",
    "AutoModel": "modeling.CTCropModel"
  },
  "backbone": "mobilenetv3_small_100",
  "dropout": 0.1,
  "feature_dim": 1024,
  "in_chans": 1,
  "model_type": "ct_crop",
  "num_classes": 4,
  "torch_dtype": "float32",
  "transformers_version": "4.47.0"
}
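The `auto_map` entries above point `AutoConfig` and `AutoModel` at the bundled `configuration.py` and `modeling.py`, so the checkpoint loads with custom code enabled. A minimal loading sketch follows; the repository id is a placeholder for wherever this commit lives on the Hub.

```python
from transformers import AutoModel

# "username/ct-crop" is a hypothetical repo id; substitute the actual repository.
model = AutoModel.from_pretrained(
    "username/ct-crop",
    trust_remote_code=True,  # lets auto_map resolve configuration.py / modeling.py
)
model.eval()
```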
configuration.py
ADDED
@@ -0,0 +1,21 @@
from transformers import PretrainedConfig


class CTCropConfig(PretrainedConfig):
    model_type = "ct_crop"

    def __init__(
        self,
        backbone="mobilenetv3_small_100",
        feature_dim=1024,
        dropout=0.1,
        num_classes=4,
        in_chans=1,
        **kwargs,
    ):
        self.backbone = backbone
        self.feature_dim = feature_dim
        self.dropout = dropout
        self.num_classes = num_classes
        self.in_chans = in_chans
        super().__init__(**kwargs)
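The config defaults mirror the values serialized in config.json, and any field can be overridden per instance. A small sketch, assuming configuration.py is importable from the working directory:

```python
from configuration import CTCropConfig  # assumes configuration.py is on the Python path

# Defaults mirror config.json; override individual fields as keyword arguments.
config = CTCropConfig(dropout=0.2)
assert config.model_type == "ct_crop"
assert config.backbone == "mobilenetv3_small_100"
print(config.to_dict())
```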
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bcbe8cb2ca2ce0530befadb6b16e8ffb11433d2e57906f5526eb63bbca9bffcf
size 6159352
modeling.py
ADDED
@@ -0,0 +1,389 @@
import cv2
import glob
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel
from timm import create_model

from .configuration import CTCropConfig

_PYDICOM_AVAILABLE = False
try:
    from pydicom import dcmread

    _PYDICOM_AVAILABLE = True
except ModuleNotFoundError:
    pass


class CTCropModel(PreTrainedModel):
    config_class = CTCropConfig

    def __init__(self, config):
        super().__init__(config)
        self.backbone = create_model(
            model_name=config.backbone,
            pretrained=False,
            num_classes=0,
            global_pool="",
            features_only=False,
            in_chans=config.in_chans,
        )
        self.dropout = nn.Dropout(p=config.dropout)
        self.linear = nn.Linear(config.feature_dim, config.num_classes)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        # [0, 255] -> [-1, 1]
        mini, maxi = 0.0, 255.0
        x = (x - mini) / (maxi - mini)
        x = (x - 0.5) * 2.0
        return x

    @staticmethod
    def window(x: np.ndarray, WL: int, WW: int) -> np.ndarray:
        # apply windowing to CT
        lower, upper = WL - WW // 2, WL + WW // 2
        x = np.clip(x, lower, upper)
        x = (x - lower) / (upper - lower)
        return (x * 255.0).astype("uint8")

    @staticmethod
    def validate_windows_type(windows):
        assert isinstance(windows, (tuple, list))
        if isinstance(windows, tuple):
            assert len(windows) == 2
            assert all(isinstance(_, int) for _ in windows)
        elif isinstance(windows, list):
            assert all([isinstance(_, tuple) for _ in windows])
            assert all([len(_) == 2 for _ in windows])
            assert all([isinstance(__, int) for _ in windows for __ in _])

    @staticmethod
    def determine_dicom_orientation(ds) -> int:
        iop = ds.ImageOrientationPatient

        # Calculate the direction cosine for the normal vector of the plane
        normal_vector = np.cross(iop[:3], iop[3:])

        # Determine the plane based on the largest component of the normal vector
        abs_normal = np.abs(normal_vector)
        if abs_normal[0] > abs_normal[1] and abs_normal[0] > abs_normal[2]:
            return 0  # sagittal
        elif abs_normal[1] > abs_normal[0] and abs_normal[1] > abs_normal[2]:
            return 1  # coronal
        else:
            return 2  # axial

    def load_image_from_dicom(
        self, path: str, windows: tuple[int, int] | list[tuple[int, int]] | None = None
    ) -> np.ndarray:
        # windows can be tuple of (WINDOW_LEVEL, WINDOW_WIDTH)
        # or list of tuples if wishing to generate multi-channel image using
        # > 1 window
        if not _PYDICOM_AVAILABLE:
            raise Exception("`pydicom` is not installed")
        dicom = dcmread(path)
        array = dicom.pixel_array.astype("float32")
        m, b = float(dicom.RescaleSlope), float(dicom.RescaleIntercept)
        array = array * m + b
        if windows is None:
            return array

        self.validate_windows_type(windows)
        if isinstance(windows, tuple):
            windows = [windows]

        arr_list = []
        for WL, WW in windows:
            arr_list.append(self.window(array.copy(), WL, WW))

        array = np.stack(arr_list, axis=-1)
        if array.shape[-1] == 1:
            array = np.squeeze(array, axis=-1)

        return array

    @staticmethod
    def is_valid_dicom(
        ds,
        fname: str = "",
        sort_by_instance_number: bool = False,
        exclude_invalid_dicoms: bool = False,
    ):
        attributes = [
            "pixel_array",
            "RescaleSlope",
            "RescaleIntercept",
        ]
        if sort_by_instance_number:
            attributes.append("InstanceNumber")
        else:
            attributes.append("ImagePositionPatient")
            attributes.append("ImageOrientationPatient")
        attributes_present = [hasattr(ds, attr) for attr in attributes]
        valid = all(attributes_present)
        if not valid and not exclude_invalid_dicoms:
            raise Exception(
                f"invalid DICOM file [{fname}]: missing attributes: {list(np.array(attributes)[~np.array(attributes_present)])}"
            )
        return valid

    @staticmethod
    def most_common_element(lst):
        return max(set(lst), key=lst.count)

    @staticmethod
    def center_crop_or_pad_borders(image, size):
        height, width = image.shape[:2]
        new_height, new_width = size
        if new_height < height:
            # crop top and bottom
            crop_top = (height - new_height) // 2
            crop_bottom = height - new_height - crop_top
            image = image[crop_top:-crop_bottom]
        elif new_height > height:
            # pad top and bottom
            pad_top = (new_height - height) // 2
            pad_bottom = new_height - height - pad_top
            image = np.pad(
                image,
                ((pad_top, pad_bottom), (0, 0)),
                mode="constant",
                constant_values=0,
            )

        if new_width < width:
            # crop left and right
            crop_left = (width - new_width) // 2
            crop_right = width - new_width - crop_left
            image = image[:, crop_left:-crop_right]
        elif new_width > width:
            # pad left and right
            pad_left = (new_width - width) // 2
            pad_right = new_width - width - pad_left
            image = np.pad(
                image,
                ((0, 0), (pad_left, pad_right)),
                mode="constant",
                constant_values=0,
            )

        return image

    def load_stack_from_dicom_folder(
        self,
        path: str,
        windows: tuple[int, int] | list[tuple[int, int]] | None = None,
        dicom_extension: str = ".dcm",
        sort_by_instance_number: bool = False,
        exclude_invalid_dicoms: bool = False,
        fix_unequal_shapes: str = "crop_pad",
        return_sorted_dicom_files: bool = False,
    ) -> np.ndarray | tuple[np.ndarray, list[str]]:
        if not _PYDICOM_AVAILABLE:
            raise Exception("`pydicom` is not installed")
        dicom_files = glob.glob(os.path.join(path, f"*{dicom_extension}"))
        if len(dicom_files) == 0:
            raise Exception(
                f"No DICOM files found in `{path}` using `dicom_extension={dicom_extension}`"
            )
        dicoms = [dcmread(f) for f in dicom_files]
        dicoms = [
            (d, dicom_files[idx])
            for idx, d in enumerate(dicoms)
            if self.is_valid_dicom(
                d, dicom_files[idx], sort_by_instance_number, exclude_invalid_dicoms
            )
        ]
        # handles exclude_invalid_dicoms=True and return_sorted_dicom_files=True
        # by only including valid DICOM filenames
        dicom_files = [_[1] for _ in dicoms]
        dicoms = [_[0] for _ in dicoms]

        slices = [dcm.pixel_array.astype("float32") for dcm in dicoms]
        shapes = np.stack([s.shape for s in slices], axis=0)
        if not np.all(shapes == shapes[0]):
            unique_shapes, counts = np.unique(shapes, axis=0, return_counts=True)
            standard_shape = tuple(unique_shapes[np.argmax(counts)])
            print(
                f"warning: different array shapes present, using {fix_unequal_shapes} -> {standard_shape}"
            )
            if fix_unequal_shapes == "crop_pad":
                slices = [
                    self.center_crop_or_pad_borders(s, standard_shape)
                    if s.shape != standard_shape
                    else s
                    for s in slices
                ]
            elif fix_unequal_shapes == "resize":
                slices = [
                    cv2.resize(s, standard_shape) if s.shape != standard_shape else s
                    for s in slices
                ]
        slices = np.stack(slices, axis=0)
        # find orientation
        orientation = [self.determine_dicom_orientation(dcm) for dcm in dicoms]
        # use most common
        orientation = self.most_common_element(orientation)

        # sort using ImagePositionPatient
        # orientation is index to use for sorting
        if sort_by_instance_number:
            positions = [float(d.InstanceNumber) for d in dicoms]
        else:
            positions = [float(d.ImagePositionPatient[orientation]) for d in dicoms]
        indices = np.argsort(positions)
        slices = slices[indices]

        # rescale
        m, b = (
            [float(d.RescaleSlope) for d in dicoms],
            [float(d.RescaleIntercept) for d in dicoms],
        )
        m, b = self.most_common_element(m), self.most_common_element(b)
        slices = slices * m + b
        if windows is not None:
            self.validate_windows_type(windows)
            if isinstance(windows, tuple):
                windows = [windows]

            arr_list = []
            for WL, WW in windows:
                arr_list.append(self.window(slices.copy(), WL, WW))

            slices = np.stack(arr_list, axis=-1)
            if slices.shape[-1] == 1:
                slices = np.squeeze(slices, axis=-1)

        if return_sorted_dicom_files:
            return slices, [dicom_files[idx] for idx in indices]
        return slices

    @staticmethod
    def preprocess(x: np.ndarray, mode="2d") -> np.ndarray:
        mode = mode.lower()
        if mode == "2d":
            x = cv2.resize(x, (256, 256))
            if x.ndim == 2:
                x = x[:, :, np.newaxis]
        elif mode == "3d":
            x = np.stack([cv2.resize(s, (256, 256)) for s in x], axis=0)
            if x.ndim == 3:
                x = x[:, :, :, np.newaxis]
        return x

    @staticmethod
    def add_buffer_to_coords(
        coords: torch.Tensor,
        buffer: float | tuple[float, float] = 0.05,
        empty_threshold: float = 1e-4,
    ):
        # assumes coords is a torch.Tensor of shape (N, 4) containing
        # normalized x, y, w, h coordinates
        # buffer is for EACH SIDE (i.e., 0.05 will add total of 0.1)
        coords = coords.clone()
        empty = (coords < empty_threshold).all(dim=1)
        assert len(coords.shape) == 2
        assert coords.shape[1] == 4
        if isinstance(buffer, float):
            buffer = buffer, buffer
        assert buffer[0] >= 0 and buffer[1] >= 0
        assert coords.min() >= 0 and coords.max() <= 1
        if buffer == (0, 0) or empty.sum() == coords.shape[0]:
            return coords
        # convert xywh->xyxy
        x1, y1, w, h = coords.unbind(1)
        x2, y2 = x1 + w, y1 + h
        # since coords are normalized, can use buffer value directly
        w_buf, h_buf = buffer
        x1, y1, x2, y2 = x1 - w_buf, y1 - h_buf, x2 + w_buf, y2 + h_buf
        x1, y1 = torch.clamp_min(x1, 0), torch.clamp_min(y1, 0)
        x2, y2 = torch.clamp_max(x2, 1), torch.clamp_max(y2, 1)
        w, h = x2 - x1, y2 - y1
        coords = torch.stack([x1, y1, w, h], dim=1)
        coords[empty] = 0
        assert coords.min() >= 0 and coords.max() <= 1
        return coords

    def forward(
        self,
        x: torch.Tensor,
        img_shape: torch.Tensor | None = None,
        add_buffer: float | tuple[float, float] | None = None,
    ) -> torch.Tensor:
        # if img_shape is provided, will provide rescaled coordinates
        # otherwise, provide normalized [0, 1] coordinates
        # coords format is xywh
        if img_shape is not None:
            assert (
                x.size(0) == img_shape.size(0)
            ), f"x.size(0) [{x.size(0)}] must equal img_shape.size(0) [{img_shape.size(0)}]"
            # img_shape = (batch_dim, 2)
            # img_shape[:, 0] = height, img_shape[:, 1] = width

        x = self.normalize(x)
        # avg pooling
        features = F.adaptive_avg_pool2d(self.backbone(x), 1).flatten(1)
        coords = self.linear(features).sigmoid()

        if add_buffer is not None:
            coords = self.add_buffer_to_coords(coords, add_buffer)

        if img_shape is None:
            return coords

        rescaled_coords = coords.clone()
        rescaled_coords[:, 0] = rescaled_coords[:, 0] * img_shape[:, 1]
        rescaled_coords[:, 1] = rescaled_coords[:, 1] * img_shape[:, 0]
        rescaled_coords[:, 2] = rescaled_coords[:, 2] * img_shape[:, 1]
        rescaled_coords[:, 3] = rescaled_coords[:, 3] * img_shape[:, 0]
        return rescaled_coords.int()

    def crop(
        self,
        x: np.ndarray,
        mode: str,
        device: str | None = None,
        raw_hu: bool = False,
        add_buffer: float | tuple[float, float] | None = None,
    ) -> np.ndarray:
        assert mode in ["2d", "3d"]
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        assert isinstance(x, np.ndarray)
        assert (
            x.ndim <= 4 and x.ndim >= 2
        ), f"# of dimensions should be 2, 3, or 4, got {x.ndim}"
        x0 = x
        if mode == "2d":
            x = np.expand_dims(x, axis=0)
        img_shapes = torch.tensor([_.shape[:2] for _ in x]).to(device)
        x = self.preprocess(x, mode="3d")
        if raw_hu:
            # if input is in Hounsfield units, apply soft tissue window
            x = self.window(x, WL=50, WW=400)
        # torchify
        x = torch.from_numpy(x)
        x = x.permute(0, 3, 1, 2).float().to(device)
        if x.size(1) > 1:
            # if multi-channel, take mean
            x = x.mean(1, keepdim=True)
        coords = self.forward(x, img_shape=img_shapes, add_buffer=add_buffer)
        # get the union of all slice-wise bounding boxes
        # exclude empty boxes
        coords = coords[coords.sum(dim=1) != 0]
        # if all empty, return original input
        if coords.shape[0] == 0:
            print("no foreground detected, returning original input ...")
            return x0
        x, y, w, h = coords.unbind(1)
        # xywh -> xyxy
        x1, y1, x2, y2 = x, y, x + w, y + h
        x1, y1 = x1.min().item(), y1.min().item()
        x2, y2 = x2.max().item(), y2.max().item()
        cropped = x0[:, y1:y2, x1:x2] if mode == "3d" else x0[y1:y2, x1:x2]
        return cropped
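Putting the pieces together, the model's helpers support an end-to-end crop of a CT series: load the stack from a DICOM folder, predict a bounding box per slice, take the union, and crop the original array. The sketch below assumes `pydicom` is installed; the repo id and DICOM path are placeholders, and `raw_hu=True` applies the built-in soft-tissue window (WL=50, WW=400) since the loaded stack is in Hounsfield units.

```python
import torch
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
# "username/ct-crop" is a placeholder repo id.
model = AutoModel.from_pretrained("username/ct-crop", trust_remote_code=True)
model = model.eval().to(device)

# Load a CT series in Hounsfield units, sorted along the slice axis.
stack = model.load_stack_from_dicom_folder("/path/to/dicom/series")

# Predict slice-wise boxes, union them, and crop the original stack.
with torch.no_grad():
    cropped = model.crop(stack, mode="3d", device=device, raw_hu=True, add_buffer=0.05)

print(stack.shape, "->", cropped.shape)
```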