import os

import streamlit as st
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image

from PathDino import get_pathDino_model

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
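
# Pick the GPU when available and load the pretrained PathDino backbone
# together with its matching image transforms.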
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, image_transforms = get_pathDino_model("PathDino512.pth")
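

# Sidebar: model description, illustrative figures, citation, and lab logos.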
st.sidebar.markdown("### PathDino")
st.sidebar.markdown(
    "PathDino is a lightweight histopathology transformer consisting of just five small "
    "vision transformer blocks. It is a customized ViT architecture, fine-tuned to the "
    "nuances of histology images, that not only performs strongly but is also less prone "
    "to overfitting, a common challenge in histology image analysis.\n\n"
)

rotation_figure = "images/HistRotate.png"
st.sidebar.image(rotation_figure, caption='A 360° rotation augmentation for training models on histopathology images. Unlike natural images, where rotation may change the context of the visual data, rotating a histopathology patch does not change its context, and it improves learning toward more reliable embeddings.', width=500)

compare_figure = "images/FigPathDino_parameters_FLOPs_compare.png"
st.sidebar.image(compare_figure, caption='PathDino vs. its counterparts: number of parameters (millions) vs. patch-level retrieval with macro-averaged F-score of the majority vote (MV@5) on the CAMELYON16 dataset. Bubble size represents FLOPs.', width=500)

activation_figure = "images/ActivationMap.png"
st.sidebar.image(activation_figure, caption='Attention visualization. When visualizing attention patterns, PathDino outperforms HIPT-small and DinoSSLPath, despite being trained on a smaller dataset of 6 million TCGA patches; DinoSSLPath and HIPT were trained on much larger datasets of 19 million and 104 million TCGA patches, respectively.', width=500)
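
# Citation rendered in the sidebar as a BibTeX snippet.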
st.sidebar.markdown("### Citation")
st.sidebar.markdown("""
```bibtex
@article{alfasly2023PathDino,
  title={Rotation-Agnostic Representation Learning for Histopathological Image Analysis},
  author={Saghir, Alfasly and Abubakr, Shafique and Peyman, Nejat and Jibran, Khan and Areej, Alsaafin and Ghazal, Alabtah and H.R.Tizhoosh},
  journal={arXiv preprint arXiv:xxxx.xxxxx},
  year={2023}
}
```""")

st.sidebar.markdown("\n\n")
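
# Lab and institution logos side by side at the bottom of the sidebar.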
col1, col2 = st.sidebar.columns(2)

url_logo_lab = "images/rhazes_lab_logo.png"
with col1:
    st.image(url_logo_lab, width=150)

url_logo_company = "images/Mayo_Clinic_logo.png"
with col2:
    st.image(url_logo_company, width=150)

st.sidebar.markdown("Rhazes Lab,  \nDepartment of Artificial Intelligence and Informatics,  \nMayo Clinic,  \nRochester, MN, USA")
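

# Extract the last self-attention map of the [CLS] token from each head and
# upsample it to the input resolution so it can be overlaid on the image.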
def visualize_attention_ViT(model, img, patch_size=16):
    attention_list = []
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    w_featmap = img.shape[-2] // patch_size
    h_featmap = img.shape[-1] // patch_size
    attentions = model.get_last_selfattention(img.to(device))
    nh = attentions.shape[1]  # number of attention heads

    # Keep only the attention of the [CLS] token to the patch tokens.
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
    attentions = attentions.reshape(nh, w_featmap, h_featmap)
    # Upsample to pixel space; move to CPU before converting to NumPy.
    attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].detach().cpu().numpy()
    for j in range(nh):
        attention_list.append(attentions[j])
    return attention_list
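

# Preprocess an uploaded PIL image and compute its per-head attention maps.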
def generate_activation_maps(image):
    preprocess = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.CenterCrop(512),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image_tensor = preprocess(image)
    img = image_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        attention_list = visualize_attention_ViT(model=model, img=img, patch_size=16)
    return attention_list
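

# Main page: upload a histology image and display its attention maps.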
st.title("PathDino - Compact ViT for Histopathology Image Analysis")
st.write("Upload a histology image to view the activation maps.")

# Example input bundled with the repository: images/HistRotate.png
uploaded_image = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])

if uploaded_image is not None:
    columns = st.columns(3)
    columns[1].image(uploaded_image, caption="Uploaded Image", width=300)

    uploaded_image = Image.open(uploaded_image).convert('RGB')
    attention_list = generate_activation_maps(uploaded_image)
    print(len(attention_list))  # number of attention heads
    st.subheader("Attention maps of the input image")
    columns = st.columns(len(attention_list) // 2)
    columns2 = st.columns(len(attention_list) // 2)
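
    # Render each attention head in its own column, split across two rows.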
    for index, col in enumerate(columns):
        fig, ax = plt.subplots()
        ax.imshow(attention_list[index])
        ax.axis("off")
        col.pyplot(fig)
        plt.close(fig)

    for index, col in enumerate(columns2):
        index = index + len(attention_list) // 2
        fig, ax = plt.subplots()
        ax.imshow(attention_list[index])
        ax.axis("off")
        col.pyplot(fig)
        plt.close(fig)
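
# A minimal way to run this demo locally (assuming this script is saved as app.py
# and PathDino512.pth plus the images/ folder sit next to it):
#   streamlit run app.py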