import os

# Work around the "duplicate OpenMP runtime" abort that torch + matplotlib
# can trigger on some platforms; setting it before the heavy imports keeps
# the workaround reliable.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

from PathDino import get_pathDino_model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load PathDino model and image transforms
model, image_transforms = get_pathDino_model("PathDino512.pth")
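# Move the model onto the selected device and switch to inference mode.
# (Assumption: get_pathDino_model returns the model on CPU without calling
# eval(); drop these two lines if the loader already handles placement.)
model = model.to(device)
model.eval()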
st.sidebar.markdown("### PathDino")
st.sidebar.markdown(
    "PathDino is a lightweight histopathology transformer consisting of just five small vision transformer blocks. "
    "It is a customized ViT architecture, tuned to the nuances of histology images: it not only exhibits superior "
    "performance but also effectively reduces susceptibility to overfitting, a common challenge in histology "
    "image analysis.\n\n"
)
default_image_url_compare = "images/HistRotate.png"
st.sidebar.image(default_image_url_compare, caption='A 360° rotation augmentation for training models on histopathology images. Unlike natural images, where rotation may change the context of the visual data, rotating a histopathology patch does not change its context, which improves learning and yields more reliable embeddings.', width=300)
default_image_url_compare = "images/FigPathDino_parameters_FLOPs_compare.png"
st.sidebar.image(default_image_url_compare, caption='PathDino vs. its counterparts: number of parameters (millions) vs. patch-level retrieval macro-average F-score of the majority vote (MV@5) on the CAMELYON16 dataset. The bubble size represents FLOPs.', width=300)
default_image_url_compare = "images/ActivationMap.png"
st.sidebar.image(default_image_url_compare, caption='Attention Visualization. When visualizing attention patterns, our PathDino transformer outperforms HIPT-small and DinoSSLPath, despite being trained on a smaller dataset of 6 million TCGA patches. In contrast, DinoSSLPath and HIPT were trained on much larger datasets, with 19 million and 104 million TCGA patches, respectively.', width=300)
st.sidebar.markdown("### Citation")
# Create a code block for citations
st.sidebar.markdown("""
```markdown
@article{alfasly2023rotationagnostic,
title={Rotation-Agnostic Image Representation Learning for Digital Pathology},
author={Saghir Alfasly and Abubakr Shafique and Peyman Nejat and Jibran Khan and Areej Alsaafin and Ghazal Alabtah and H. R. Tizhoosh},
year={2023},
eprint={2311.08359},
archivePrefix={arXiv},
primaryClass={cs.CV}
}""")
# Affiliation
st.sidebar.markdown("\n\n")
st.sidebar.markdown("KIMIA Lab, Department of Artificial Intelligence and Informatics,  \nMayo Clinic,  \nRochester, MN, USA")
def visualize_attention_ViT(model, img, patch_size=16):
    """Return one attention map per head for the CLS token of the last block."""
    w_featmap = img.shape[-2] // patch_size
    h_featmap = img.shape[-1] // patch_size
    attentions = model.get_last_selfattention(img.to(device))
    nh = attentions.shape[1]  # number of attention heads
    # Keep only the attention from the CLS token to the image patches
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
    attentions = attentions.reshape(nh, w_featmap, h_featmap)
    # Upsample each head's map back to pixel resolution
    attentions = nn.functional.interpolate(
        attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest"
    )[0].detach().cpu().numpy()
    return [attentions[j] for j in range(nh)]
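# Shape sanity check (illustrative, assuming a 512x512 input, patch size 16,
# and 6 attention heads as in ViT-small):
#   img:              (1, 3, 512, 512)
#   attentions:       (1, 6, 1025, 1025)   # 32*32 patches + 1 CLS token
#   each returned map: (512, 512) after nearest-neighbour upsampling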
# Generate the per-head attention maps for a PIL image
def generate_activation_maps(image, patch_size=16):
    img = np.array(image)
    # Crop the image so both sides are divisible by the patch size
    w = img.shape[0] - img.shape[0] % patch_size
    h = img.shape[1] - img.shape[1] % patch_size
    preprocess = transforms.Compose([
        transforms.Resize((img.shape[0], img.shape[1])),
        transforms.CenterCrop((w, h)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet statistics
    ])
    image_tensor = preprocess(image)
    img = image_tensor.unsqueeze(0).to(device)
    # Generate activation maps
    with torch.no_grad():
        attention_list = visualize_attention_ViT(model=model, img=img, patch_size=patch_size)
    return attention_list
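# Illustrative usage (hypothetical file path):
#   maps = generate_activation_maps(Image.open("patch.png").convert("RGB"))
#   -> list of per-head (H, W) float arrays, H and W cropped to multiples of 16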
# Streamlit UI
st.title("PathDino - Compact ViT for Histopathology Image Analysis")
st.write("Upload a histology image to view the attention maps.")
uploaded_image = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
if uploaded_image is not None:
    st.image(uploaded_image, caption="Uploaded Image", width=500)
    # Load the image and compute one attention map per head
    uploaded_image = Image.open(uploaded_image).convert('RGB')
    attention_list = generate_activation_maps(uploaded_image)
    st.subheader("Attention Maps of the input image")
    # Display the attention heads in a 3x2 grid (assumes the model's 6 heads)
    rows = [st.columns(2) for _ in range(3)]
    for row_idx, row in enumerate(rows):
        for col_idx, col in enumerate(row):
            index = row_idx * 2 + col_idx
            fig = plt.figure()
            # Hide the axes so only the attention heatmap is shown
            plt.gca().axes.get_xaxis().set_visible(False)
            plt.gca().axes.get_yaxis().set_visible(False)
            plt.imshow(attention_list[index])
            plt.title(f"Head-{index + 1}")
            col.pyplot(fig)
            plt.close(fig)