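# PdiscoFormer Part Discovery Visualizer
# A Streamlit demo that loads a pretrained PdiscoFormer checkpoint and
# overlays the discovered part attention maps on an uploaded image.
# To try it locally (assuming this file is saved as app.py and streamlit,
# torch and Pillow are installed): streamlit run app.py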
import streamlit as st
import torch
from PIL import Image

from models import IndividualLandmarkViT
from utils import VisualizeAttentionMaps
from utils.transform_utils import make_test_transforms

st.title("PdiscoFormer Part Discovery Visualizer")
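# Pretrained checkpoints hosted on the Hugging Face Hub; the k_<n> suffix
# denotes the number of foreground parts the model discovers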
model_options = ["ananthu-aniraj/pdiscoformer_cub_k_8", "ananthu-aniraj/pdiscoformer_cub_k_16",
                 "ananthu-aniraj/pdiscoformer_cub_k_4", "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_8",
                 "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_25",
                 "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_50",
                 "ananthu-aniraj/pdiscoformer_flowers_k_2", "ananthu-aniraj/pdiscoformer_flowers_k_4",
                 "ananthu-aniraj/pdiscoformer_flowers_k_8", "ananthu-aniraj/pdiscoformer_nabirds_k_4",
                 "ananthu-aniraj/pdiscoformer_nabirds_k_8", "ananthu-aniraj/pdiscoformer_nabirds_k_11",
                 "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_8", "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_16",
                 "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_25", "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_41",
                 "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_50"]
model_name = st.selectbox("Select a model", model_options)
if model_name is not None:
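    # The CUB and NABirds checkpoints expect 518x518 inputs; the remaining
    # checkpoints were trained at 224x224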
    if "cub" in model_name or "nabirds" in model_name:
        image_size = 518
    else:
        image_size = 224
    # Set the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Download and load the pretrained checkpoint from the Hugging Face Hub
    model = IndividualLandmarkViT.from_pretrained(model_name, input_size=image_size).eval().to(device)
    num_parts = model.num_landmarks
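    # The visualizer renders num_parts foreground maps plus one background
    # channel, with the background assigned label index num_parts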
    amap_vis = VisualizeAttentionMaps(num_parts=num_parts + 1, bg_label=num_parts)
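    # Test-time preprocessing (resizing and normalization) for the selected
    # input size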
    test_transforms = make_test_transforms(image_size)

    # Instructions
    st.write("If you choose to upload an image, the attention map will be displayed.")
    st.write("The attention map will highlight the regions of the image that the model is focusing on.")
    st.write("The model is trained to focus on different parts of the salient objects in the image based on the dataset.")
    st.write("If you choose one of the CUB or NABirds models, please choose a bird image.")
    st.write("If you choose one of the Flower models, please choose a flower image.")
    st.write("If you choose one of the PartImageNet models, please choose an images of classes from PartImageNet like land animals/birds/cars/bottles/airplanes.")

    image_name = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])  # Upload an image
    if image_name is not None:
        image = Image.open(image_name).convert("RGB")
        image_tensor = test_transforms(image).unsqueeze(0).to(device)
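        # Run the forward pass without gradient tracking; the model returns
        # the part attention maps and the class scores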
        with torch.no_grad():
            maps, scores = model(image_tensor)

        # Overlay the part assignment map on the input image and display it
        coloured_map = amap_vis.show_maps(image_tensor, maps)
        st.image(coloured_map, caption="Attention Map", use_column_width=True)