File size: 1,036 Bytes
a8d9779
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import streamlit as st
import torch
from PIL import Image
from models import IndividualLandmarkViT
from utils import VisualizeAttentionMaps
from utils.data_utils.transform_utils import make_test_transforms


st.title("Pdiscoformer Part Discovery Visualizer")

# Pretrained checkpoint to visualize. NOTE(review): "k_8" presumably means 8
# discovered parts; NUM_PARTS / BG_LABEL below must stay consistent with it.
MODEL_NAME = "ananthu-aniraj/pdiscoformer_cub_k_8"
# Number of attention-map channels handed to the visualizer, and the channel
# index it treats as background (assumed: 8 foreground parts + 1 background).
NUM_PARTS = 9
BG_LABEL = 8

# Set the device: use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def load_model(model_name, device_name):
    """Load the pretrained model in eval mode on ``device_name``, once per process.

    Streamlit re-executes this entire script on every user interaction; without
    ``st.cache_resource`` the checkpoint would be re-instantiated and moved to
    the device on every rerun. The device is passed as a string so the cache
    key is a plain, hashable value.
    """
    return IndividualLandmarkViT.from_pretrained(model_name).eval().to(device_name)


# Load the model (cached across reruns).
model = load_model(MODEL_NAME, str(device))

amap_vis = VisualizeAttentionMaps(num_parts=NUM_PARTS, bg_label=BG_LABEL)

image_size = 518
test_transforms = make_test_transforms(image_size)
image_name = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])  # Upload an image
if image_name is not None:
    try:
        image = Image.open(image_name).convert("RGB")
    except OSError as err:
        # The uploader only filters by extension; PIL raises
        # UnidentifiedImageError (an OSError subclass) for undecodable files.
        st.error(f"Could not read the uploaded file as an image: {err}")
    else:
        # Add a batch dimension of 1 before moving to the model's device.
        image_tensor = test_transforms(image).unsqueeze(0).to(device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            # `scores` is currently unused; only the attention maps are drawn.
            maps, scores = model(image_tensor)

        coloured_map = amap_vis.show_maps(image_tensor, maps)
        st.image(coloured_map, caption="Attention Map", use_column_width=True)