# pdiscoformer / app.py
# Initial version uploaded by ananthu-aniraj (commit a8d9779, 1.04 kB).
import streamlit as st
import torch
from PIL import Image
from models import IndividualLandmarkViT
from utils import VisualizeAttentionMaps
from utils.data_utils.transform_utils import make_test_transforms
st.title("Pdiscoformer Part Discovery Visualizer")

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def _load_model():
    """Load the pretrained part-discovery model once and reuse it across reruns.

    Streamlit re-executes this script from top to bottom on every widget
    interaction, including every file upload. Caching the loaded model as a
    resource avoids re-instantiating the checkpoint and moving it to the
    device on each rerun.

    Returns:
        The ``IndividualLandmarkViT`` model in eval mode, placed on the
        module-level ``device``.
    """
    # Deliberately takes no arguments: ``device`` is fixed for the lifetime of
    # the process, so there is nothing that needs to be part of the cache key.
    return IndividualLandmarkViT.from_pretrained("ananthu-aniraj/pdiscoformer_cub_k_8").eval().to(device)


# Load the model (cached; see _load_model)
model = _load_model()

# NOTE(review): presumably 8 foreground parts plus one background map (index 8)
# to match the "k_8" checkpoint -- confirm against VisualizeAttentionMaps.
amap_vis = VisualizeAttentionMaps(num_parts=9, bg_label=8)

image_size = 518
test_transforms = make_test_transforms(image_size)

image_name = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])  # Upload an image

if image_name is not None:
    image = Image.open(image_name).convert("RGB")
    # unsqueeze(0) adds a leading batch dimension of size 1 before inference.
    image_tensor = test_transforms(image).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        # ``scores`` is returned by the model but not used by the visualisation.
        maps, scores = model(image_tensor)

    coloured_map = amap_vis.show_maps(image_tensor, maps)
    # NOTE(review): ``use_column_width`` is deprecated in recent Streamlit
    # releases in favour of ``use_container_width``; kept unchanged here so the
    # app keeps working on the Streamlit version it was written for -- confirm
    # the pinned version before switching.
    st.image(coloured_map, caption="Attention Map", use_column_width=True)