ianpan committed
Commit 4cbe18f · verified · 1 Parent(s): 862883c

Upload model

Files changed (5)
  1. README.md +199 -0
  2. config.json +17 -0
  3. configuration.py +21 -0
  4. model.safetensors +3 -0
  5. modeling.py +389 -0
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
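
The "How to Get Started with the Model" section above is still a template placeholder. A minimal quick-start sketch, assuming the model is loaded from the Hub with its custom code enabled; the repo id and DICOM path below are placeholders, and the (50, 400) soft-tissue window mirrors the default applied in `modeling.py`:

```python
from transformers import AutoModel

# placeholder repo id; substitute the actual Hub repository for this commit
model = AutoModel.from_pretrained("<namespace>/<repo-name>", trust_remote_code=True)
model.eval()

# read one DICOM slice and apply a soft-tissue window (WL=50, WW=400),
# which yields a uint8 image in [0, 255] of shape (H, W)
image = model.load_image_from_dicom("slice.dcm", windows=(50, 400))

# predict the foreground bounding box and return the cropped slice
cropped = model.crop(image, mode="2d")
print(image.shape, "->", cropped.shape)
```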
config.json ADDED
@@ -0,0 +1,17 @@
+ {
+   "architectures": [
+     "CTCropModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration.CTCropConfig",
+     "AutoModel": "modeling.CTCropModel"
+   },
+   "backbone": "mobilenetv3_small_100",
+   "dropout": 0.1,
+   "feature_dim": 1024,
+   "in_chans": 1,
+   "model_type": "ct_crop",
+   "num_classes": 4,
+   "torch_dtype": "float32",
+   "transformers_version": "4.47.0"
+ }
configuration.py ADDED
@@ -0,0 +1,21 @@
+ from transformers import PretrainedConfig
+
+
+ class CTCropConfig(PretrainedConfig):
+     model_type = "ct_crop"
+
+     def __init__(
+         self,
+         backbone="mobilenetv3_small_100",
+         feature_dim=1024,
+         dropout=0.1,
+         num_classes=4,
+         in_chans=1,
+         **kwargs,
+     ):
+         self.backbone = backbone
+         self.feature_dim = feature_dim
+         self.dropout = dropout
+         self.num_classes = num_classes
+         self.in_chans = in_chans
+         super().__init__(**kwargs)
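
A small sketch of how this configuration behaves, assuming `configuration.py` above is importable (for example, from a local clone of the repo); the defaults simply mirror the values serialized in `config.json`:

```python
from configuration import CTCropConfig

config = CTCropConfig()  # defaults mirror config.json above
print(config.model_type)   # ct_crop
print(config.backbone)     # mobilenetv3_small_100
print(config.feature_dim, config.num_classes, config.in_chans)  # 1024 4 1
print(config.to_json_string())  # custom fields plus the standard transformers metadata
```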
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bcbe8cb2ca2ce0530befadb6b16e8ffb11433d2e57906f5526eb63bbca9bffcf
+ size 6159352
modeling.py ADDED
@@ -0,0 +1,389 @@
+ import cv2
+ import glob
+ import numpy as np
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from transformers import PreTrainedModel
+ from timm import create_model
+
+ from .configuration import CTCropConfig
+
+ _PYDICOM_AVAILABLE = False
+ try:
+     from pydicom import dcmread
+
+     _PYDICOM_AVAILABLE = True
+ except ModuleNotFoundError:
+     pass
+
+
+ class CTCropModel(PreTrainedModel):
+     config_class = CTCropConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.backbone = create_model(
+             model_name=config.backbone,
+             pretrained=False,
+             num_classes=0,
+             global_pool="",
+             features_only=False,
+             in_chans=config.in_chans,
+         )
+         self.dropout = nn.Dropout(p=config.dropout)
+         self.linear = nn.Linear(config.feature_dim, config.num_classes)
+
+     def normalize(self, x: torch.Tensor) -> torch.Tensor:
+         # [0, 255] -> [-1, 1]
+         mini, maxi = 0.0, 255.0
+         x = (x - mini) / (maxi - mini)
+         x = (x - 0.5) * 2.0
+         return x
+
+     @staticmethod
+     def window(x: np.ndarray, WL: int, WW: int) -> np.ndarray[np.uint8]:
+         # applying windowing to CT
+         lower, upper = WL - WW // 2, WL + WW // 2
+         x = np.clip(x, lower, upper)
+         x = (x - lower) / (upper - lower)
+         return (x * 255.0).astype("uint8")
+
+     @staticmethod
+     def validate_windows_type(windows):
+         assert isinstance(windows, tuple) or isinstance(windows, list)
+         if isinstance(windows, tuple):
+             assert len(windows) == 2
+             assert all(isinstance(_, int) for _ in windows)
+         elif isinstance(windows, list):
+             assert all([isinstance(_, tuple) for _ in windows])
+             assert all([len(_) == 2 for _ in windows])
+             assert all([isinstance(__, int) for _ in windows for __ in _])
+
+     @staticmethod
+     def determine_dicom_orientation(ds) -> int:
+         iop = ds.ImageOrientationPatient
+
+         # Calculate the direction cosine for the normal vector of the plane
+         normal_vector = np.cross(iop[:3], iop[3:])
+
+         # Determine the plane based on the largest component of the normal vector
+         abs_normal = np.abs(normal_vector)
+         if abs_normal[0] > abs_normal[1] and abs_normal[0] > abs_normal[2]:
+             return 0  # sagittal
+         elif abs_normal[1] > abs_normal[0] and abs_normal[1] > abs_normal[2]:
+             return 1  # coronal
+         else:
+             return 2  # axial
+
+     def load_image_from_dicom(
+         self, path: str, windows: tuple[int, int] | list[tuple[int, int]] | None = None
+     ) -> np.ndarray:
+         # windows can be tuple of (WINDOW_LEVEL, WINDOW_WIDTH)
+         # or list of tuples if wishing to generate multi-channel image using
+         # > 1 window
+         if not _PYDICOM_AVAILABLE:
+             raise Exception("`pydicom` is not installed")
+         dicom = dcmread(path)
+         array = dicom.pixel_array.astype("float32")
+         m, b = float(dicom.RescaleSlope), float(dicom.RescaleIntercept)
+         array = array * m + b
+         if windows is None:
+             return array
+
+         self.validate_windows_type(windows)
+         if isinstance(windows, tuple):
+             windows = [windows]
+
+         arr_list = []
+         for WL, WW in windows:
+             arr_list.append(self.window(array.copy(), WL, WW))
+
+         array = np.stack(arr_list, axis=-1)
+         if array.shape[-1] == 1:
+             array = np.squeeze(array, axis=-1)
+
+         return array
+
+     @staticmethod
+     def is_valid_dicom(
+         ds,
+         fname: str = "",
+         sort_by_instance_number: bool = False,
+         exclude_invalid_dicoms: bool = False,
+     ):
+         attributes = [
+             "pixel_array",
+             "RescaleSlope",
+             "RescaleIntercept",
+         ]
+         if sort_by_instance_number:
+             attributes.append("InstanceNumber")
+         else:
+             attributes.append("ImagePositionPatient")
+         attributes.append("ImageOrientationPatient")
+         attributes_present = [hasattr(ds, attr) for attr in attributes]
+         valid = all(attributes_present)
+         if not valid and not exclude_invalid_dicoms:
+             raise Exception(
+                 f"invalid DICOM file [{fname}]: missing attributes: {list(np.array(attributes)[~np.array(attributes_present)])}"
+             )
+         return valid
+
+     @staticmethod
+     def most_common_element(lst):
+         return max(set(lst), key=lst.count)
+
+     @staticmethod
+     def center_crop_or_pad_borders(image, size):
+         height, width = image.shape[:2]
+         new_height, new_width = size
+         if new_height < height:
+             # crop top and bottom
+             crop_top = (height - new_height) // 2
+             crop_bottom = height - new_height - crop_top
+             image = image[crop_top:-crop_bottom]
+         elif new_height > height:
+             # pad top and bottom
+             pad_top = (new_height - height) // 2
+             pad_bottom = new_height - height - pad_top
+             image = np.pad(
+                 image,
+                 ((pad_top, pad_bottom), (0, 0)),
+                 mode="constant",
+                 constant_values=0,
+             )
+
+         if new_width < width:
+             # crop left and right
+             crop_left = (width - new_width) // 2
+             crop_right = width - new_width - crop_left
+             image = image[:, crop_left:-crop_right]
+         elif new_width > width:
+             # pad left and right
+             pad_left = (new_width - width) // 2
+             pad_right = new_width - width - pad_left
+             image = np.pad(
+                 image,
+                 ((0, 0), (pad_left, pad_right)),
+                 mode="constant",
+                 constant_values=0,
+             )
+
+         return image
+
+     def load_stack_from_dicom_folder(
+         self,
+         path: str,
+         windows: tuple[int, int] | list[tuple[int, int]] | None = None,
+         dicom_extension: str = ".dcm",
+         sort_by_instance_number: bool = False,
+         exclude_invalid_dicoms: bool = False,
+         fix_unequal_shapes: str = "crop_pad",
+         return_sorted_dicom_files: bool = False,
+     ) -> np.ndarray | tuple[np.ndarray, list[str]]:
+         if not _PYDICOM_AVAILABLE:
+             raise Exception("`pydicom` is not installed")
+         dicom_files = glob.glob(os.path.join(path, f"*{dicom_extension}"))
+         if len(dicom_files) == 0:
+             raise Exception(
+                 f"No DICOM files found in `{path}` using `dicom_extension={dicom_extension}`"
+             )
+         dicoms = [dcmread(f) for f in dicom_files]
+         dicoms = [
+             (d, dicom_files[idx])
+             for idx, d in enumerate(dicoms)
+             if self.is_valid_dicom(
+                 d, dicom_files[idx], sort_by_instance_number, exclude_invalid_dicoms
+             )
+         ]
+         # handles exclude_invalid_dicoms=True and return_sorted_dicom_files=True
+         # by only including valid DICOM filenames
+         dicom_files = [_[1] for _ in dicoms]
+         dicoms = [_[0] for _ in dicoms]
+
+         slices = [dcm.pixel_array.astype("float32") for dcm in dicoms]
+         shapes = np.stack([s.shape for s in slices], axis=0)
+         if not np.all(shapes == shapes[0]):
+             unique_shapes, counts = np.unique(shapes, axis=0, return_counts=True)
+             standard_shape = tuple(unique_shapes[np.argmax(counts)])
+             print(
+                 f"warning: different array shapes present, using {fix_unequal_shapes} -> {standard_shape}"
+             )
+             if fix_unequal_shapes == "crop_pad":
+                 slices = [
+                     self.center_crop_or_pad_borders(s, standard_shape)
+                     if s.shape != standard_shape
+                     else s
+                     for s in slices
+                 ]
+             elif fix_unequal_shapes == "resize":
+                 slices = [
+                     cv2.resize(s, standard_shape[::-1]) if s.shape != standard_shape else s  # cv2 dsize is (w, h)
+                     for s in slices
+                 ]
+         slices = np.stack(slices, axis=0)
+         # find orientation
+         orientation = [self.determine_dicom_orientation(dcm) for dcm in dicoms]
+         # use most common
+         orientation = self.most_common_element(orientation)
+
+         # sort using ImagePositionPatient
+         # orientation is index to use for sorting
+         if sort_by_instance_number:
+             positions = [float(d.InstanceNumber) for d in dicoms]
+         else:
+             positions = [float(d.ImagePositionPatient[orientation]) for d in dicoms]
+         indices = np.argsort(positions)
+         slices = slices[indices]
+
+         # rescale
+         m, b = (
+             [float(d.RescaleSlope) for d in dicoms],
+             [float(d.RescaleIntercept) for d in dicoms],
+         )
+         m, b = self.most_common_element(m), self.most_common_element(b)
+         slices = slices * m + b
+         if windows is not None:
+             self.validate_windows_type(windows)
+             if isinstance(windows, tuple):
+                 windows = [windows]
+
+             arr_list = []
+             for WL, WW in windows:
+                 arr_list.append(self.window(slices.copy(), WL, WW))
+
+             slices = np.stack(arr_list, axis=-1)
+             if slices.shape[-1] == 1:
+                 slices = np.squeeze(slices, axis=-1)
+
+         if return_sorted_dicom_files:
+             return slices, [dicom_files[idx] for idx in indices]
+         return slices
+
+     @staticmethod
+     def preprocess(x: np.ndarray, mode="2d") -> np.ndarray:
+         mode = mode.lower()
+         if mode == "2d":
+             x = cv2.resize(x, (256, 256))
+             if x.ndim == 2:
+                 x = x[:, :, np.newaxis]
+         elif mode == "3d":
+             x = np.stack([cv2.resize(s, (256, 256)) for s in x], axis=0)
+             if x.ndim == 3:
+                 x = x[:, :, :, np.newaxis]
+         return x
+
+     @staticmethod
+     def add_buffer_to_coords(
+         coords: torch.Tensor,
+         buffer: float | tuple[float, float] = 0.05,
+         empty_threshold: float = 1e-4,
+     ):
+         coords = coords.clone()
+         empty = (coords < empty_threshold).all(dim=1)
+         # assumes coords is a torch.Tensor of shape (N, 4) containing
+         # normalized x, y, w, h coordinates
+         # buffer is for EACH SIDE (i.e., 0.05 will add total of 0.1)
+         assert len(coords.shape) == 2
+         assert coords.shape[1] == 4
+         if isinstance(buffer, float):
+             buffer = buffer, buffer
+         assert buffer[0] >= 0 and buffer[1] >= 0
+         assert coords.min() >= 0 and coords.max() <= 1
+         if buffer == 0 or empty.sum() == coords.shape[0]:
+             return coords
+         # convert xywh->xyxy
+         x1, y1, w, h = coords.unbind(1)
+         x2, y2 = x1 + w, y1 + h
+         # since coords are normalized, can use buffer value directly
+         w_buf, h_buf = buffer
+         x1, y1, x2, y2 = x1 - w_buf, y1 - h_buf, x2 + w_buf, y2 + h_buf
+         x1, y1 = torch.clamp_min(x1, 0), torch.clamp_min(y1, 0)
+         x2, y2 = torch.clamp_max(x2, 1), torch.clamp_max(y2, 1)
+         w, h = x2 - x1, y2 - y1
+         coords = torch.stack([x1, y1, w, h], dim=1)
+         coords[empty] = 0
+         assert coords.min() >= 0 and coords.max() <= 1
+         return coords
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         img_shape: torch.Tensor | None = None,
+         add_buffer: float | tuple[float, float] | None = None,
+     ) -> torch.Tensor:
+         # if img_shape is provided, will provide rescaled coordinates
+         # otherwise, provide normalized [0, 1] coordinates
+         # coords format is xywh
+         if img_shape is not None:
+             assert (
+                 x.size(0) == img_shape.size(0)
+             ), f"x.size(0) [{x.size(0)}] must equal img_shape.size(0) [{img_shape.size(0)}]"
+             # img_shape = (batch_dim, 2)
+             # img_shape[:, 0] = height, img_shape[:, 1] = width
+
+         x = self.normalize(x)
+         # avg pooling
+         features = F.adaptive_avg_pool2d(self.backbone(x), 1).flatten(1)
+         coords = self.linear(features).sigmoid()
+
+         if add_buffer is not None:
+             coords = self.add_buffer_to_coords(coords, add_buffer)
+
+         if img_shape is None:
+             return coords
+
+         rescaled_coords = coords.clone()
+         rescaled_coords[:, 0] = rescaled_coords[:, 0] * img_shape[:, 1]
+         rescaled_coords[:, 1] = rescaled_coords[:, 1] * img_shape[:, 0]
+         rescaled_coords[:, 2] = rescaled_coords[:, 2] * img_shape[:, 1]
+         rescaled_coords[:, 3] = rescaled_coords[:, 3] * img_shape[:, 0]
+         return rescaled_coords.int()
+
+     def crop(
+         self,
+         x: np.ndarray,
+         mode: str,
+         device: str | None = None,
+         raw_hu: bool = False,
+         add_buffer: float | tuple[float, float] | None = None,
+     ) -> np.ndarray:
+         assert mode in ["2d", "3d"]
+         if device is None:
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+         assert isinstance(x, np.ndarray)
+         assert (
+             x.ndim <= 4 and x.ndim >= 2
+         ), f"# of dimensions should be 2, 3, or 4, got {x.ndim}"
+         x0 = x
+         if mode == "2d":
+             x = np.expand_dims(x, axis=0)
+         img_shapes = torch.tensor([_.shape[:2] for _ in x]).to(device)
+         x = self.preprocess(x, mode="3d")
+         if raw_hu:
+             # if input is in Hounsfield units, apply soft tissue window
+             x = self.window(x, WL=50, WW=400)
+         # torchify
+         x = torch.from_numpy(x)
+         x = x.permute(0, 3, 1, 2).float().to(device)
+         if x.size(1) > 1:
+             # if multi-channel, take mean
+             x = x.mean(1, keepdim=True)
+         coords = self.forward(x, img_shape=img_shapes, add_buffer=add_buffer)
+         # get the union of all slice-wise bounding boxes
+         # exclude empty boxes
+         coords = coords[coords.sum(dim=1) != 0]
+         # if all empty, return original input
+         if coords.shape[0] == 0:
+             print("no foreground detected, returning original input ...")
+             return x0
+         x, y, w, h = coords.unbind(1)
+         # xywh -> xyxy
+         x1, y1, x2, y2 = x, y, x + w, y + h
+         x1, y1 = x1.min().item(), y1.min().item()
+         x2, y2 = x2.max().item(), y2.max().item()
+         cropped = x0[:, y1:y2, x1:x2] if mode == "3d" else x0[y1:y2, x1:x2]
+         return cropped
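
Taken together, the files above give an end-to-end cropping path for a CT series. A usage sketch under stated assumptions: the repo id and DICOM folder are placeholders, and `pydicom`, `timm`, and `opencv-python` must be installed alongside `transformers` and `torch`:

```python
import torch
from transformers import AutoModel

# placeholder repo id and series path; adjust to your environment
model = AutoModel.from_pretrained("<namespace>/<repo-name>", trust_remote_code=True)
model.eval()

# load and sort a CT series into a (num_slices, H, W) float32 stack in Hounsfield units
stack = model.load_stack_from_dicom_folder("/path/to/dicom/series", windows=None)

with torch.inference_mode():
    # raw_hu=True applies the soft-tissue window (WL=50, WW=400) internally;
    # add_buffer=0.05 pads each side of the predicted box by 5% before cropping
    cropped = model.crop(stack, mode="3d", raw_hu=True, add_buffer=0.05)

print(stack.shape, "->", cropped.shape)
```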