Commit
·
7c91758
1
Parent(s):
67cbdeb
Upload 2 files
Browse files- Util/Custom_Model.py +44 -0
- Util/DICOM.py +81 -0
Util/Custom_Model.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import timm
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
class Build_Custom_Model(nn.Module):
|
5 |
+
def __init__(self, model_name, target_size, pretrained=False):
|
6 |
+
super().__init__()
|
7 |
+
self.model = timm.create_model(model_name, pretrained=pretrained, in_chans=1)
|
8 |
+
if(model_name=="vit_base_patch16_224" or model_name=="swin_base_patch4_window7_224"):
|
9 |
+
self.n_features = self.model.head.in_features
|
10 |
+
self.model.head = nn.Linear(self.n_features, target_size)
|
11 |
+
if(model_name=="resnet34d"):
|
12 |
+
self.n_features = self.model.fc.in_features
|
13 |
+
self.model.fc = nn.Linear(self.n_features, target_size)
|
14 |
+
if(model_name=="resnet18d"):
|
15 |
+
self.n_features = self.model.fc.in_features
|
16 |
+
self.model.fc = nn.Linear(self.n_features, target_size)
|
17 |
+
if(model_name=="tf_efficientnet_b7_ns"):
|
18 |
+
self.n_features = self.model.classifier.in_features
|
19 |
+
self.model.classifier = nn.Linear(self.n_features, target_size)
|
20 |
+
if(model_name=="tf_efficientnet_b0_ns"):
|
21 |
+
self.n_features = self.model.classifier.in_features
|
22 |
+
self.model.classifier = nn.Linear(self.n_features, target_size)
|
23 |
+
if(model_name=="tf_efficientnet_lite0"):
|
24 |
+
self.n_features = self.model.classifier.in_features
|
25 |
+
self.model.classifier = nn.Linear(self.n_features, target_size)
|
26 |
+
if(model_name=="mobilenetv2_050"):
|
27 |
+
self.n_features = self.model.classifier.in_features
|
28 |
+
self.model.classifier = nn.Linear(self.n_features, target_size)
|
29 |
+
if(model_name=="eca_nfnet_l0"):
|
30 |
+
self.n_features = self.model.head.fc.in_features
|
31 |
+
self.model.head.fc = nn.Linear(self.n_features, target_size)
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
output = self.model(x)
|
35 |
+
return output
|
36 |
+
|
37 |
+
def reshape_transform(tensor, height=7, width=7):
|
38 |
+
result = tensor.reshape(tensor.size(0),
|
39 |
+
height, width, tensor.size(2))
|
40 |
+
|
41 |
+
# Bring the channels to the first dimension,
|
42 |
+
# like in CNNs.
|
43 |
+
result = result.transpose(2, 3).transpose(1, 2)
|
44 |
+
return result
|
Util/DICOM.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
class DICOM_Utils(object):
|
5 |
+
def apply_windowing(image_array, window_center, window_width):
|
6 |
+
"""
|
7 |
+
Apply windowing to a DICOM image array.
|
8 |
+
|
9 |
+
Parameters:
|
10 |
+
- image_array: numpy array of the DICOM image
|
11 |
+
- window_center: center of the window
|
12 |
+
- window_width: width of the window
|
13 |
+
|
14 |
+
Returns:
|
15 |
+
- Windowed image array
|
16 |
+
"""
|
17 |
+
lower_bound = window_center - (window_width / 2)
|
18 |
+
upper_bound = window_center + (window_width / 2)
|
19 |
+
|
20 |
+
# Apply windowing
|
21 |
+
windowed_image = image_array.copy()
|
22 |
+
windowed_image[windowed_image < lower_bound] = lower_bound
|
23 |
+
windowed_image[windowed_image > upper_bound] = upper_bound
|
24 |
+
|
25 |
+
# Normalize to [0, 255]
|
26 |
+
windowed_image = ((windowed_image - lower_bound) / window_width) * 255
|
27 |
+
|
28 |
+
return windowed_image.astype('uint8')
|
29 |
+
|
30 |
+
def transform_image_for_display(image_array):
|
31 |
+
"""
|
32 |
+
Transform the image for display: Flip horizontally and then rotate 90 degrees to the right.
|
33 |
+
|
34 |
+
Parameters:
|
35 |
+
- image_array: numpy array of the image
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
- Transformed image array
|
39 |
+
"""
|
40 |
+
# Flip horizontally
|
41 |
+
flipped_image = np.fliplr(image_array)
|
42 |
+
|
43 |
+
# Rotate 90 degrees to the right
|
44 |
+
rotated_image = np.rot90(flipped_image, 1)
|
45 |
+
|
46 |
+
return rotated_image
|
47 |
+
|
48 |
+
def apply_CAM_overlay(heatmap, windowed_image, overlay_alpha=0.4):
|
49 |
+
"""
|
50 |
+
Apply CAM (Class Activation Map) overlay to a given image.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
- heatmap: torch.Tensor, the heatmap generated by CAM.
|
54 |
+
- windowed_image: numpy.ndarray, the windowed image to overlay the heatmap on.
|
55 |
+
- overlay_alpha: float, the transparency for overlaying heatmap. Default is 0.4.
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
- overlayed: numpy.ndarray, the resulting image after overlaying the heatmap.
|
59 |
+
"""
|
60 |
+
# Convert the heatmap tensor to a numpy array
|
61 |
+
heatmap_np = heatmap.cpu().numpy().squeeze()
|
62 |
+
|
63 |
+
# Normalize the heatmap to [0, 255]
|
64 |
+
heatmap_normalized = ((heatmap_np - heatmap_np.min()) /
|
65 |
+
(heatmap_np.max() - heatmap_np.min()) * 255).astype(np.uint8)
|
66 |
+
|
67 |
+
# Convert the normalized heatmap to a colormap (for example, using the "jet" colormap)
|
68 |
+
heatmap_colormap = cv2.applyColorMap(heatmap_normalized, cv2.COLORMAP_JET)
|
69 |
+
|
70 |
+
# Resize the colormap to the original image size
|
71 |
+
heatmap_resized = cv2.resize(heatmap_colormap,
|
72 |
+
(windowed_image.shape[1], windowed_image.shape[0]))
|
73 |
+
|
74 |
+
# Convert the grayscale windowed_image to 3 channels
|
75 |
+
windowed_image_colored = cv2.cvtColor(windowed_image, cv2.COLOR_GRAY2BGR)
|
76 |
+
|
77 |
+
# Overlay the heatmap on the original image with a certain transparency
|
78 |
+
overlayed = cv2.addWeighted(windowed_image_colored, 1 - overlay_alpha,
|
79 |
+
heatmap_resized, overlay_alpha, 0)
|
80 |
+
|
81 |
+
return overlayed
|