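"""Streamlit demo for PDiscoFormer part discovery.

Pick a pretrained checkpoint from the Hugging Face Hub, upload an image, and
view the discovered part attention maps.

Run with `streamlit run <this file>` (the exact file name depends on the repo).
"""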
import streamlit as st
import torch
from PIL import Image

from models import IndividualLandmarkViT
from utils import VisualizeAttentionMaps
from utils.transform_utils import make_test_transforms

st.title("PdiscoFormer Part Discovery Visualizer")
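# Pretrained checkpoints on the Hugging Face Hub; the k_<N> suffix is the number of parts the model discovers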
model_options = ["ananthu-aniraj/pdiscoformer_cub_k_8", "ananthu-aniraj/pdiscoformer_cub_k_16",
                 "ananthu-aniraj/pdiscoformer_cub_k_4", "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_8",
                 "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_25",
                 "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_50",
                 "ananthu-aniraj/pdiscoformer_flowers_k_2", "ananthu-aniraj/pdiscoformer_flowers_k_4",
                 "ananthu-aniraj/pdiscoformer_flowers_k_8", "ananthu-aniraj/pdiscoformer_nabirds_k_4",
                 "ananthu-aniraj/pdiscoformer_nabirds_k_8", "ananthu-aniraj/pdiscoformer_nabirds_k_11",
                 "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_8", "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_16",
                 "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_25", "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_41",
                 "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_50"]
model_name = st.selectbox("Select a model", model_options)
if model_name is not None:
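    # CUB and NABirds checkpoints take 518x518 inputs; the others take 224x224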
    if "cub" in model_name or "nabirds" in model_name:
        image_size = 518
    else:
        image_size = 224
    # Set the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load the model
    model = IndividualLandmarkViT.from_pretrained(model_name, input_size=image_size).eval().to(device)
    num_parts = model.num_landmarks
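    # The visualizer gets one extra channel for the background, labelled with index num_parts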
    amap_vis = VisualizeAttentionMaps(num_parts=num_parts + 1, bg_label=num_parts)
    test_transforms = make_test_transforms(image_size)
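    # Preprocessing pipeline matching the selected input size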
    # Instructions
    if "cub" in model_name or "nabirds" in model_name:
        st.write("Upload an image of a bird to visualize the attention maps")
    elif "flowers" in model_name:
        st.write("Upload an image of a flower to visualize the attention maps")
    else:
        st.write("Upload an image of any PartImageNet class (land animals + fish + cars + airplanes) to visualize the attention maps")

    image_name = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])  # Upload an image
    if image_name is not None:
        image = Image.open(image_name).convert("RGB")
        image_tensor = test_transforms(image).unsqueeze(0).to(device)
        with torch.no_grad():
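            # Single forward pass: part assignment maps plus classification scores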
            maps, scores = model(image_tensor)

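        # Render the part maps as a coloured overlay on the input image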
        coloured_map = amap_vis.show_maps(image_tensor, maps)
        st.image(coloured_map, caption="Attention Map", use_column_width=True)