File size: 4,247 Bytes
56fb801
 
 
 
 
 
 
 
 
 
 
 
 
82a65d2
 
56fb801
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bd781a
ea7721e
 
56fb801
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82a65d2
56fb801
 
 
 
82a65d2
56fb801
82a65d2
56fb801
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82a65d2
 
 
56fb801
 
82a65d2
56fb801
 
 
 
 
82a65d2
56fb801
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Copyright (C) 2020-2021, François-Guillaume Fernandez.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import requests
import streamlit as st
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
from torchvision import models
from torchvision.transforms.functional import resize, to_tensor, normalize, to_pil_image

from torchcam import methods
from torchcam.methods._utils import locate_candidate_layer
from torchcam.utils import overlay_mask


# Names of the `torchcam.methods` classes offered in the sidebar
CAM_METHODS = ["CAM", "GradCAM", "GradCAMpp", "SmoothGradCAMpp", "ScoreCAM", "SSCAM", "ISCAM", "XGradCAM", "LayerCAM"]
# Names of the `torchvision.models` constructors offered in the sidebar
TV_MODELS = ["resnet18", "resnet50", "mobilenet_v2", "mobilenet_v3_small", "mobilenet_v3_large"]
# Source of the human-readable class labels (a JSON list)
LABEL_MAP_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
# Seconds to wait for the label server before giving up; without a timeout,
# `requests` can block forever and the app would never finish starting
LABEL_MAP_TIMEOUT = 10


def _fetch_label_map(url, timeout):
    """Download and decode the JSON label list served at *url*.

    Raises:
        requests.HTTPError: if the server answers with an error status, instead of
            letting `.json()` fail later on an HTML error page.
        requests.Timeout: if the server does not respond within *timeout* seconds.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response.json()


# Class labels, fetched once at import time; `main` treats a label's list
# position as its class index
LABEL_MAP = _fetch_label_map(LABEL_MAP_URL, LABEL_MAP_TIMEOUT)


def main():
    """Render the TorchCAM class-activation explorer.

    Lays out three columns (input image, raw CAM, overlayed CAM). The sidebar
    lets the user upload an image and choose a torchvision model, a target
    layer, a CAM method and a target class. When "Compute CAM" is clicked, the
    image is classified, the activation map for the chosen class is extracted,
    and both the raw heatmap and its overlay on the input are displayed.
    """

    # Wide mode
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("TorchCAM: class activation explorer")
    # For newline
    st.write('\n')
    st.write('Check the project at: https://github.com/frgfm/torch-cam')
    # For newline
    st.write('\n')
    # Set the columns
    cols = st.columns((1, 1, 1))
    cols[0].header("Input image")
    cols[1].header("Raw CAM")
    cols[-1].header("Overlayed CAM")

    # Sidebar
    # File selection
    st.sidebar.title("Input selection")
    # Disabling warning
    st.set_option('deprecation.showfileUploaderEncoding', False)
    # Choose your own image
    uploaded_file = st.sidebar.file_uploader("Upload files", type=['png', 'jpeg', 'jpg'])
    if uploaded_file is not None:
        img = Image.open(BytesIO(uploaded_file.read()), mode='r').convert('RGB')

        cols[0].image(img, use_column_width=True)

    # Model selection
    st.sidebar.title("Setup")
    tv_model = st.sidebar.selectbox("Classification model", TV_MODELS)
    default_layer = ""
    if tv_model is not None:
        # NOTE(review): the pretrained model is rebuilt on every script rerun
        # (i.e. on every widget interaction); consider caching it.
        with st.spinner('Loading model...'):
            model = models.__dict__[tv_model](pretrained=True).eval()
        # Suggest a layer for a 3x224x224 input, matching the preprocessing below
        default_layer = locate_candidate_layer(model, (3, 224, 224))

    target_layer = st.sidebar.text_input("Target layer", default_layer)
    cam_method = st.sidebar.selectbox("CAM method", CAM_METHODS)
    if cam_method is not None:
        # Several layers can be targeted by joining their names with "+";
        # an empty field passes None to the extractor
        cam_extractor = methods.__dict__[cam_method](
            model,
            target_layer=target_layer.split("+") if len(target_layer) > 0 else None
        )

    # Each choice is "<idx + 1> - <label>", so the class index can be parsed back
    argmax_choice = "Predicted class (argmax)"
    class_choices = [f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)]
    class_selection = st.sidebar.selectbox("Class selection", [argmax_choice] + class_choices)

    # For newline
    st.sidebar.write('\n')

    if st.sidebar.button("Compute CAM"):

        if uploaded_file is None:
            st.sidebar.error("Please upload an image first")

        else:
            with st.spinner('Analyzing...'):

                # Preprocess image (resize + ImageNet mean/std normalization)
                img_tensor = normalize(to_tensor(resize(img, (224, 224))), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

                # Forward the image to the model
                out = model(img_tensor.unsqueeze(0))
                # Select the target class
                if class_selection == argmax_choice:
                    class_idx = out.squeeze(0).argmax().item()
                else:
                    # Parse the index from the "<idx + 1> - " prefix instead of searching
                    # LABEL_MAP by name: a name lookup returns the first match, so a
                    # duplicated label would silently select the wrong class.
                    class_idx = int(class_selection.partition(" - ")[0]) - 1
                # Retrieve the CAM
                cams = cam_extractor(class_idx, out)
                # Fuse the CAMs if there are several
                cam = cams[0] if len(cams) == 1 else cam_extractor.fuse_cams(cams)
                # Plot the raw heatmap
                fig, ax = plt.subplots()
                ax.imshow(cam.numpy())
                ax.axis('off')
                cols[1].pyplot(fig)
                # pyplot keeps figures alive until closed; release it so figures
                # do not accumulate across reruns
                plt.close(fig)

                # Overlayed CAM
                fig, ax = plt.subplots()
                result = overlay_mask(img, to_pil_image(cam, mode='F'), alpha=0.5)
                ax.imshow(result)
                ax.axis('off')
                cols[-1].pyplot(fig)
                plt.close(fig)


# Entry point: build the app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()