# PathDino / app.py
import streamlit as st
import torch
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
import torch.nn as nn
from PathDino import get_pathDino_model
import os
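# Work around the "OMP: Error #15" duplicate OpenMP runtime crash that can occur when
# PyTorch and MKL each load their own libiomp (a common workaround in conda environments).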
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load PathDino model and image transforms
model, image_transforms = get_pathDino_model("PathDino512.pth")
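# Note: the image_transforms returned here are not used below; generate_activation_maps
# defines its own preprocessing pipeline for the 512x512 input expected by PathDino512.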
st.sidebar.markdown("### PathDino")
st.sidebar.markdown(
"PathDino is a lightweight histopathology transformer consisting of just five small vision transformer blocks. "
"PathDino is a customized ViT architecture, finely tuned to the nuances of histology images. It not only exhibits "
"superior performance but also effectively reduces susceptibility to overfitting, a common challenge in histology "
"image analysis.\n\n"
)
default_image_url_compare = "images/HistRotate.png"
st.sidebar.image(default_image_url_compare, caption='A 360° rotation augmentation for training models on histopathology images. Unlike natural images, where rotation may change the context of the visual data, rotating a histopathology patch does not change its context, which improves the learning process and yields more reliable embeddings.', width=500)
default_image_url_compare = "images/FigPathDino_parameters_FLOPs_compare.png"
st.sidebar.image(default_image_url_compare, caption='PathDino vs. its counterparts. Number of parameters (millions) vs. patch-level retrieval macro-average F-score of the majority vote (MV@5) on the CAMELYON16 dataset. Bubble size represents FLOPs.', width=500)
default_image_url_compare = "images/ActivationMap.png"
st.sidebar.image(default_image_url_compare, caption='Attention Visualization. When visualizing attention patterns, our PathDino transformer outperforms HIPT-small and DinoSSLPath, despite being trained on a smaller dataset of 6 million TCGA patches. In contrast, DinoSSLPath and HIPT were trained on much larger datasets, with 19 million and 104 million TCGA patches, respectively.', width=500)
st.sidebar.markdown("### Citation")
# Create a code block for citations
st.sidebar.markdown("""
```markdown
@article{alfasly2023PathDino,
title={Rotation-Agnostic Representation Learning for Histopathological Image Analysis},
author={Saghir, Alfasly and Abubakr, Shafique and Peyman, Nejat and Jibran, Khan and Areej, Alsaafin and Ghazal, Alabtah and H.R.Tizhoosh},
journal={arXiv preprint arXiv:xxxx.xxxxx},
year={2023}""")
# Rhazes Lab and Mayo Clinic Logos
st.sidebar.markdown("\n\n")
# Create a two-column layout for logos and text
col1, col2 = st.sidebar.columns(2)
# Logo and text for My Logo Lab
url_logo_lab = "images/rhazes_lab_logo.png"
with col1:
    st.image(url_logo_lab, width=150)
# Logo and text for My Logo Company
url_logo_company = "images/Mayo_Clinic_logo.png"
with col2:
    st.image(url_logo_company, width=150)
st.sidebar.markdown("Rhazes Lab, \n Department of Artificial Intelligence and Informatics, \n Mayo Clinic, \n Rochester, MN, USA")
def visualize_attention_ViT(model, img, patch_size=16):
    attention_list = []
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    w_featmap = img.shape[-2] // patch_size
    h_featmap = img.shape[-1] // patch_size
    attentions = model.get_last_selfattention(img.to(device))
    nh = attentions.shape[1]  # number of attention heads
    # Keep only the [CLS]-token attention over the image patches (drop the [CLS]-to-[CLS] entry)
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
    attentions = attentions.reshape(nh, w_featmap, h_featmap)
    # Upsample each head's attention map from the patch grid back to pixel resolution
    attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].detach().cpu().numpy()
    for j in range(nh):
        attention_list.append(attentions[j])
    return attention_list
# Define the function to generate activation maps
def generate_activation_maps(image):
    preprocess = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.CenterCrop(512),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize with ImageNet statistics
    ])
    image_tensor = preprocess(image)
    img = image_tensor.unsqueeze(0).to(device)
    # Generate activation maps
    with torch.no_grad():
        attention_list = visualize_attention_ViT(model=model, img=img, patch_size=16)
    return attention_list
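
# With the 512x512 input above and patch_size=16, the patch grid is 512/16 = 32 x 32
# (1,024 patch tokens plus one [CLS] token), so each head's attention map is upsampled
# back to 512x512 before plotting.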
# Streamlit UI
st.title("PathDino - Compact ViT for histopathology Image Analysis")
st.write("Upload a histology image to view the activation maps.")
uploaded_image = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
if uploaded_image is not None:
    columns = st.columns(3)
    columns[1].image(uploaded_image, caption="Uploaded Image", width=300)

    # Load the image and apply preprocessing
    uploaded_image = Image.open(uploaded_image).convert('RGB')
    attention_list = generate_activation_maps(uploaded_image)

    st.subheader("Attention Maps of the input image")
    # Lay the per-head attention maps out over two rows of columns
    half = len(attention_list) // 2
    columns = st.columns(half)
    columns2 = st.columns(half)
    for index, col in enumerate(columns + columns2):
        # Plot each head's attention map without axes
        plt.figure()
        plt.gca().axes.get_xaxis().set_visible(False)
        plt.gca().axes.get_yaxis().set_visible(False)
        plt.imshow(attention_list[index])
        col.pyplot(plt)
        plt.close()
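
# To run the app locally (assuming PathDino512.pth and the images/ assets are present):
#   streamlit run app.py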