Spaces:
Build error
Build error
Commit
·
7c91758
1
Parent(s):
67cbdeb
Upload 2 files
Browse files- Util/Custom_Model.py +44 -0
- Util/DICOM.py +81 -0
Util/Custom_Model.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import timm
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class Build_Custom_Model(nn.Module):
|
| 5 |
+
def __init__(self, model_name, target_size, pretrained=False):
|
| 6 |
+
super().__init__()
|
| 7 |
+
self.model = timm.create_model(model_name, pretrained=pretrained, in_chans=1)
|
| 8 |
+
if(model_name=="vit_base_patch16_224" or model_name=="swin_base_patch4_window7_224"):
|
| 9 |
+
self.n_features = self.model.head.in_features
|
| 10 |
+
self.model.head = nn.Linear(self.n_features, target_size)
|
| 11 |
+
if(model_name=="resnet34d"):
|
| 12 |
+
self.n_features = self.model.fc.in_features
|
| 13 |
+
self.model.fc = nn.Linear(self.n_features, target_size)
|
| 14 |
+
if(model_name=="resnet18d"):
|
| 15 |
+
self.n_features = self.model.fc.in_features
|
| 16 |
+
self.model.fc = nn.Linear(self.n_features, target_size)
|
| 17 |
+
if(model_name=="tf_efficientnet_b7_ns"):
|
| 18 |
+
self.n_features = self.model.classifier.in_features
|
| 19 |
+
self.model.classifier = nn.Linear(self.n_features, target_size)
|
| 20 |
+
if(model_name=="tf_efficientnet_b0_ns"):
|
| 21 |
+
self.n_features = self.model.classifier.in_features
|
| 22 |
+
self.model.classifier = nn.Linear(self.n_features, target_size)
|
| 23 |
+
if(model_name=="tf_efficientnet_lite0"):
|
| 24 |
+
self.n_features = self.model.classifier.in_features
|
| 25 |
+
self.model.classifier = nn.Linear(self.n_features, target_size)
|
| 26 |
+
if(model_name=="mobilenetv2_050"):
|
| 27 |
+
self.n_features = self.model.classifier.in_features
|
| 28 |
+
self.model.classifier = nn.Linear(self.n_features, target_size)
|
| 29 |
+
if(model_name=="eca_nfnet_l0"):
|
| 30 |
+
self.n_features = self.model.head.fc.in_features
|
| 31 |
+
self.model.head.fc = nn.Linear(self.n_features, target_size)
|
| 32 |
+
|
| 33 |
+
def forward(self, x):
|
| 34 |
+
output = self.model(x)
|
| 35 |
+
return output
|
| 36 |
+
|
| 37 |
+
def reshape_transform(tensor, height=7, width=7):
|
| 38 |
+
result = tensor.reshape(tensor.size(0),
|
| 39 |
+
height, width, tensor.size(2))
|
| 40 |
+
|
| 41 |
+
# Bring the channels to the first dimension,
|
| 42 |
+
# like in CNNs.
|
| 43 |
+
result = result.transpose(2, 3).transpose(1, 2)
|
| 44 |
+
return result
|
Util/DICOM.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
class DICOM_Utils(object):
|
| 5 |
+
def apply_windowing(image_array, window_center, window_width):
|
| 6 |
+
"""
|
| 7 |
+
Apply windowing to a DICOM image array.
|
| 8 |
+
|
| 9 |
+
Parameters:
|
| 10 |
+
- image_array: numpy array of the DICOM image
|
| 11 |
+
- window_center: center of the window
|
| 12 |
+
- window_width: width of the window
|
| 13 |
+
|
| 14 |
+
Returns:
|
| 15 |
+
- Windowed image array
|
| 16 |
+
"""
|
| 17 |
+
lower_bound = window_center - (window_width / 2)
|
| 18 |
+
upper_bound = window_center + (window_width / 2)
|
| 19 |
+
|
| 20 |
+
# Apply windowing
|
| 21 |
+
windowed_image = image_array.copy()
|
| 22 |
+
windowed_image[windowed_image < lower_bound] = lower_bound
|
| 23 |
+
windowed_image[windowed_image > upper_bound] = upper_bound
|
| 24 |
+
|
| 25 |
+
# Normalize to [0, 255]
|
| 26 |
+
windowed_image = ((windowed_image - lower_bound) / window_width) * 255
|
| 27 |
+
|
| 28 |
+
return windowed_image.astype('uint8')
|
| 29 |
+
|
| 30 |
+
def transform_image_for_display(image_array):
|
| 31 |
+
"""
|
| 32 |
+
Transform the image for display: Flip horizontally and then rotate 90 degrees to the right.
|
| 33 |
+
|
| 34 |
+
Parameters:
|
| 35 |
+
- image_array: numpy array of the image
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
- Transformed image array
|
| 39 |
+
"""
|
| 40 |
+
# Flip horizontally
|
| 41 |
+
flipped_image = np.fliplr(image_array)
|
| 42 |
+
|
| 43 |
+
# Rotate 90 degrees to the right
|
| 44 |
+
rotated_image = np.rot90(flipped_image, 1)
|
| 45 |
+
|
| 46 |
+
return rotated_image
|
| 47 |
+
|
| 48 |
+
def apply_CAM_overlay(heatmap, windowed_image, overlay_alpha=0.4):
|
| 49 |
+
"""
|
| 50 |
+
Apply CAM (Class Activation Map) overlay to a given image.
|
| 51 |
+
|
| 52 |
+
Parameters:
|
| 53 |
+
- heatmap: torch.Tensor, the heatmap generated by CAM.
|
| 54 |
+
- windowed_image: numpy.ndarray, the windowed image to overlay the heatmap on.
|
| 55 |
+
- overlay_alpha: float, the transparency for overlaying heatmap. Default is 0.4.
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
- overlayed: numpy.ndarray, the resulting image after overlaying the heatmap.
|
| 59 |
+
"""
|
| 60 |
+
# Convert the heatmap tensor to a numpy array
|
| 61 |
+
heatmap_np = heatmap.cpu().numpy().squeeze()
|
| 62 |
+
|
| 63 |
+
# Normalize the heatmap to [0, 255]
|
| 64 |
+
heatmap_normalized = ((heatmap_np - heatmap_np.min()) /
|
| 65 |
+
(heatmap_np.max() - heatmap_np.min()) * 255).astype(np.uint8)
|
| 66 |
+
|
| 67 |
+
# Convert the normalized heatmap to a colormap (for example, using the "jet" colormap)
|
| 68 |
+
heatmap_colormap = cv2.applyColorMap(heatmap_normalized, cv2.COLORMAP_JET)
|
| 69 |
+
|
| 70 |
+
# Resize the colormap to the original image size
|
| 71 |
+
heatmap_resized = cv2.resize(heatmap_colormap,
|
| 72 |
+
(windowed_image.shape[1], windowed_image.shape[0]))
|
| 73 |
+
|
| 74 |
+
# Convert the grayscale windowed_image to 3 channels
|
| 75 |
+
windowed_image_colored = cv2.cvtColor(windowed_image, cv2.COLOR_GRAY2BGR)
|
| 76 |
+
|
| 77 |
+
# Overlay the heatmap on the original image with a certain transparency
|
| 78 |
+
overlayed = cv2.addWeighted(windowed_image_colored, 1 - overlay_alpha,
|
| 79 |
+
heatmap_resized, overlay_alpha, 0)
|
| 80 |
+
|
| 81 |
+
return overlayed
|