# pdiscoformer/app.py
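"""Streamlit demo for PDiscoFormer: select a pretrained checkpoint, upload an
image, and visualize the part attention maps the model discovers."""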
import streamlit as st
import torch
from PIL import Image
from models import IndividualLandmarkViT
from utils import VisualizeAttentionMaps
from utils.transform_utils import make_test_transforms
st.title("PdiscoFormer Part Discovery Visualizer")
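# Pretrained checkpoints on the Hugging Face Hub. The k_<n> suffix is the
# number of parts each model discovers; datasets are CUB, PartImageNet
# (OOD and Seg), Flowers, and NABirds.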
model_options = ["ananthu-aniraj/pdiscoformer_cub_k_8", "ananthu-aniraj/pdiscoformer_cub_k_16",
"ananthu-aniraj/pdiscoformer_cub_k_4", "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_8",
"ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_25",
"ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_50",
"ananthu-aniraj/pdiscoformer_flowers_k_2", "ananthu-aniraj/pdiscoformer_flowers_k_4",
"ananthu-aniraj/pdiscoformer_flowers_k_8", "ananthu-aniraj/pdiscoformer_nabirds_k_4",
"ananthu-aniraj/pdiscoformer_nabirds_k_8", "ananthu-aniraj/pdiscoformer_nabirds_k_11",
"ananthu-aniraj/pdiscoformer_pimagenet_seg_k_8", "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_16",
"ananthu-aniraj/pdiscoformer_pimagenet_seg_k_25", "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_41",
"ananthu-aniraj/pdiscoformer_pimagenet_seg_k_50"]
model_name = st.selectbox("Select a model", model_options)
if model_name is not None:
    # CUB and NABirds checkpoints use 518x518 inputs; the others use 224x224
    if "cub" in model_name or "nabirds" in model_name:
        image_size = 518
    else:
        image_size = 224
    # Set the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load the pretrained model and put it in eval mode
    model = IndividualLandmarkViT.from_pretrained(model_name, input_size=image_size).eval().to(device)
    # The model discovers num_landmarks parts; the visualizer gets one extra
    # label for the background, mapped to bg_label
    num_parts = model.num_landmarks
    amap_vis = VisualizeAttentionMaps(num_parts=num_parts + 1, bg_label=num_parts)
    # Eval-time image transforms for the selected input size
    test_transforms = make_test_transforms(image_size)
    # Instructions: tailor the upload prompt to the selected model's dataset
    if "cub" in model_name or "nabirds" in model_name:
        upload_text = "Upload an image of a bird to visualize the attention maps"
    elif "flowers" in model_name:
        upload_text = "Upload an image of a flower to visualize the attention maps"
    else:
        upload_text = "Upload an image of any PartImageNet class (land animals + fish + cars + airplanes) to visualize the attention maps"
    image_name = st.file_uploader(upload_text, type=["jpg", "jpeg", "png"])  # Upload an image
    if image_name is not None:
        image = Image.open(image_name).convert("RGB")
        image_tensor = test_transforms(image).unsqueeze(0).to(device)
        # Forward pass; only the part maps are visualized, the class scores are unused here
        with torch.no_grad():
            maps, scores = model(image_tensor)
        coloured_map = amap_vis.show_maps(image_tensor, maps)
        st.image(coloured_map, caption="Attention Map", use_column_width=True)
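# To try this demo locally (a sketch, assuming the repo's models/ and utils/
# packages plus streamlit and torch are installed):
#   streamlit run app.py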