Spaces:
Sleeping
Sleeping
File size: 2,828 Bytes
a8d9779 c348c61 a8d9779 5662f96 a8d9779 5662f96 c348c61 083cd0d c348c61 270559d 083cd0d 270559d 083cd0d 270559d c348c61 a8d9779 c348c61 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import streamlit as st
import torch
from PIL import Image
from models import IndividualLandmarkViT
from utils import VisualizeAttentionMaps
from utils.transform_utils import make_test_transforms
st.title("PdiscoFormer Part Discovery Visualizer")

# Checkpoint identifiers passed to IndividualLandmarkViT.from_pretrained
# (presumably Hugging Face Hub repo ids -- confirm against the models package).
model_options = [
    "ananthu-aniraj/pdiscoformer_cub_k_8",
    "ananthu-aniraj/pdiscoformer_cub_k_16",
    "ananthu-aniraj/pdiscoformer_cub_k_4",
    "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_8",
    "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_25",
    "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_50",
    "ananthu-aniraj/pdiscoformer_flowers_k_2",
    "ananthu-aniraj/pdiscoformer_flowers_k_4",
    "ananthu-aniraj/pdiscoformer_flowers_k_8",
    "ananthu-aniraj/pdiscoformer_nabirds_k_4",
    "ananthu-aniraj/pdiscoformer_nabirds_k_8",
    "ananthu-aniraj/pdiscoformer_nabirds_k_11",
    "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_8",
    "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_16",
    "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_25",
    "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_41",
    "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_50",
]


def _image_size_for(model_name: str) -> int:
    """Return the input size used for *model_name*.

    CUB and NABirds checkpoints use an input size of 518; all others use 224.
    The same value is passed to the model constructor and to the test-time
    transforms so the two always agree.
    """
    if "cub" in model_name or "nabirds" in model_name:
        return 518
    return 224


# Streamlit re-executes this whole script on every widget interaction, so the
# checkpoint is loaded through a resource cache instead of on every rerun.
# max_entries bounds how many loaded models stay resident when users switch
# between checkpoints.
@st.cache_resource(max_entries=2)
def _load_model(model_name: str, image_size: int, device_type: str):
    """Load *model_name* in eval mode on *device_type* and cache it.

    Args:
        model_name: Checkpoint identifier from ``model_options``.
        image_size: Input size forwarded as ``input_size``.
        device_type: ``"cuda"`` or ``"cpu"``; passed as a plain string so the
            cache key is trivially hashable.

    Returns:
        The loaded ``IndividualLandmarkViT`` in eval mode on the requested device.
    """
    model = IndividualLandmarkViT.from_pretrained(model_name, input_size=image_size)
    return model.eval().to(torch.device(device_type))


model_name = st.selectbox("Select a model", model_options)
if model_name is not None:
    image_size = _image_size_for(model_name)
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_model(model_name, image_size, device.type)
    num_parts = model.num_landmarks
    # One extra label is added for the background, which takes the last index.
    amap_vis = VisualizeAttentionMaps(num_parts=num_parts + 1, bg_label=num_parts)
    test_transforms = make_test_transforms(image_size)

    # Instructions
    st.write("If you choose to upload an image, the attention map will be displayed.")
    st.write("The attention map will highlight the regions of the image that the model is focusing on.")
    st.write("If you choose one of the CUB or NABirds models, please choose a bird image.")
    st.write("If you choose one of the Flower models, please choose a flower image.")
    st.write("If you choose one of the PartImageNet models, please choose an image of one of the PartImageNet classes, like land animals/birds/cars/bottles/airplanes.")

    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        try:
            image = Image.open(uploaded_file).convert("RGB")
        except OSError:
            # Pillow raises UnidentifiedImageError (an OSError subclass) for
            # unrecognised data and OSError for truncated files; report the
            # problem instead of crashing the page with a traceback.
            st.error("The uploaded file could not be read as an image.")
        else:
            # Add a batch dimension of 1 before moving to the model's device.
            image_tensor = test_transforms(image).unsqueeze(0).to(device)
            with torch.no_grad():
                # The second output is not used by this visualizer.
                maps, _scores = model(image_tensor)
            coloured_map = amap_vis.show_maps(image_tensor, maps)
            st.image(coloured_map, caption="Attention Map", use_column_width=True)
|