ananthu-aniraj commited on
Commit
c348c61
·
1 Parent(s): 2e619b6

try adding all models

Browse files
Files changed (1) hide show
  1. app.py +33 -20
app.py CHANGED
@@ -1,28 +1,41 @@
1
  import streamlit as st
2
  import torch
3
  from PIL import Image
 
4
  from models import IndividualLandmarkViT
5
  from utils import VisualizeAttentionMaps
6
  from utils.data_utils.transform_utils import make_test_transforms
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- st.title("Pdiscoformer Part Discovery Visualizer for CUB-200-2011/birds (K=8)")
10
- # Set the device
11
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
- # Load the model
13
- model = IndividualLandmarkViT.from_pretrained("ananthu-aniraj/pdiscoformer_cub_k_8").eval().to(device)
14
-
15
- amap_vis = VisualizeAttentionMaps(num_parts=9, bg_label=8)
16
-
17
- image_size = 518
18
- test_transforms = make_test_transforms(image_size)
19
- image_name = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"]) # Upload an image
20
- if image_name is not None:
21
- image = Image.open(image_name).convert("RGB")
22
- image_tensor = test_transforms(image).unsqueeze(0).to(device)
23
- with torch.no_grad():
24
- maps, scores = model(image_tensor)
25
-
26
- coloured_map = amap_vis.show_maps(image_tensor, maps)
27
- st.image(coloured_map, caption="Attention Map", use_column_width=True)
28
-
 
1
  import streamlit as st
2
  import torch
3
  from PIL import Image
4
+
5
  from models import IndividualLandmarkViT
6
  from utils import VisualizeAttentionMaps
7
  from utils.data_utils.transform_utils import make_test_transforms
8
 
9
+ st.title("Pdiscoformer Part Discovery Visualizer")
10
+ model_options = ["ananthu-aniraj/pdiscoformer_cub_k_8", "ananthu-aniraj/pdiscoformer_cub_k_16",
11
+ "ananthu-aniraj/pdiscoformer_cub_k_4", "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_8",
12
+ "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_25",
13
+ "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_50",
14
+ "ananthu-aniraj/pdiscoformer_flowers_k_2", "ananthu-aniraj/pdiscoformer_flowers_k_4",
15
+ "ananthu-aniraj/pdiscoformer_flowers_k_8", "ananthu-aniraj/pdiscoformer_nabirds_k_4",
16
+ "ananthu-aniraj/pdiscoformer_nabirds_k_8", "ananthu-aniraj/pdiscoformer_nabirds_k_11",
17
+ "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_8", "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_16",
18
+ "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_25", "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_41",
19
+ "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_50"]
20
+ model_name = st.selectbox("Select a model", model_options)
21
+ if model_name is not None:
22
+ if "cub" in model_name or "nabirds" in model_name:
23
+ image_size = 518
24
+ else:
25
+ image_size = 224
26
+ # Set the device
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ # Load the model
29
+ model = IndividualLandmarkViT.from_pretrained(model_name, input_size=image_size).eval().to(device)
30
+ num_parts = model.num_landmarks
31
+ amap_vis = VisualizeAttentionMaps(num_parts=num_parts+1, bg_label=num_parts)
32
+ test_transforms = make_test_transforms(image_size)
33
+ image_name = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"]) # Upload an image
34
+ if image_name is not None:
35
+ image = Image.open(image_name).convert("RGB")
36
+ image_tensor = test_transforms(image).unsqueeze(0).to(device)
37
+ with torch.no_grad():
38
+ maps, scores = model(image_tensor)
39
 
40
+ coloured_map = amap_vis.show_maps(image_tensor, maps)
41
+ st.image(coloured_map, caption="Attention Map", use_column_width=True)