Sijuade commited on
Commit
982420e
·
1 Parent(s): ee3265a

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -199
app.py DELETED
@@ -1,199 +0,0 @@
1
- import torch, torchvision
2
- from torchvision import transforms
3
- import numpy as np
4
- import gradio as gr
5
- from PIL import Image
6
- from pytorch_grad_cam import GradCAM
7
- from pytorch_grad_cam.utils.image import show_cam_on_image
8
- from model.network import ResNet18
9
- import matplotlib.pyplot as plt
10
- import PIL
11
- import io
12
- from PIL import Image
13
-
14
- from model.network import *
15
- from utils.gradio_utils import *
16
- from augment.augment import *
17
- from dataset.dataset import *
18
-
19
-
20
-
21
- model = ResNet18(20, None)
22
- model = model.load_from_checkpoint("resnet18.ckpt", map_location=torch.device("cpu"))
23
-
24
- dataloader_args = dict(shuffle=True, batch_size=64)
25
- _, test_transforms = get_transforms(mu, std)
26
-
27
- test = CIFAR10Dataset(transform=test_transforms, train=False)
28
- test_loader = torch.utils.data.DataLoader(test, **dataloader_args)
29
-
30
- target_layers = [model.res_block2.conv[-1]]
31
- targets = None
32
- device = torch.device("cpu")
33
-
34
- examples = get_examples()
35
-
36
- def upload_image_inference(input_img, n_top_classes, transparency):
37
-
38
- org_img = input_img.copy()
39
-
40
- input_img = test_transforms(image=org_img)['image']
41
- input_img = input_img.unsqueeze(0)
42
-
43
- outputs = model(input_img)
44
-
45
- softmax = torch.nn.Softmax(dim=0)
46
- o = softmax(outputs.flatten())
47
- confidences = {classes[i]: float(o[i]) for i in range(n_top_classes)}
48
- _, prediction = torch.max(outputs, 1)
49
-
50
- cam = GradCAM(model=model, target_layers=target_layers)
51
-
52
- grayscale_cam = cam(input_tensor=input_img, targets=None)
53
- grayscale_cam = grayscale_cam[0, :]
54
- img = input_img.squeeze(0)
55
- img = inv_normalize(img)
56
-
57
- rgb_img = np.transpose(img.cpu(), (1, 2, 0))
58
- rgb_img = rgb_img.numpy()
59
- visualization = show_cam_on_image(org_img/255, grayscale_cam, use_rgb=True, image_weight=transparency)
60
-
61
- return([confidences, [org_img, grayscale_cam, visualization]])
62
-
63
-
64
- def misclass_gr(num_images, layer_val, transparency):
65
- images_list = misclassified_data[:num_images]
66
-
67
- images_list = [image_to_array(img, layer_val, transparency) for img in images_list]
68
- return(images_list)
69
-
70
-
71
- def class_gr(num_images, layer_val, transparency):
72
- images_list = classified_data[:num_images]
73
-
74
- images_list = [image_to_array(img, layer_val, transparency) for img in images_list]
75
- return(images_list)
76
-
77
-
78
- def image_to_array(input_img, layer_val, transparency=0.6):
79
- input_tensor = input_img[0]
80
-
81
- cam = GradCAM(model=model, target_layers=[model.res_block2.conv[-layer_val]])
82
- grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
83
- grayscale_cam = grayscale_cam[0, :]
84
-
85
- img = input_tensor.squeeze(0)
86
- img = inv_normalize(img)
87
- rgb_img = np.transpose(img, (1, 2, 0))
88
- rgb_img = rgb_img.numpy()
89
-
90
- visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True,
91
- image_weight=transparency)
92
-
93
- plt.imshow(visualization)
94
- plt.title(r"Correct: " + classes[input_img[1].item()] + '\n' + 'Output: ' + classes[input_img[2].item()])
95
-
96
- with io.BytesIO() as buffer:
97
- plt.savefig(buffer, format = "png")
98
- buffer.seek(0)
99
- image = Image.open(buffer)
100
- ar = np.asarray(image)
101
-
102
- return(ar)
103
-
104
-
105
- def get_misclassified_data(model, device, test_loader):
106
- """
107
- Function to run the model on test set and return misclassified images
108
- :param model: Network Architecture
109
- :param device: CPU/GPU
110
- :param test_loader: DataLoader for test set
111
- """
112
- mis_count = 0
113
- correct_count = 0
114
-
115
- # Prepare the model for evaluation i.e. drop the dropout layer
116
- model.eval()
117
- # List to store misclassified Images
118
- misclassified_data, classified_data = [], []
119
- # Reset the gradients
120
- with torch.no_grad():
121
- # Extract images, labels in a batch
122
- for data, target in test_loader:
123
- # Migrate the data to the device
124
- data, target = data.to(device), target.to(device)
125
- # Extract single image, label from the batch
126
- for image, label in zip(data, target):
127
- # Add batch dimension to the image
128
- image = image.unsqueeze(0)
129
- # Get the model prediction on the image
130
- output = model(image)
131
- # Convert the output from one-hot encoding to a value
132
- pred = output.argmax(dim=1, keepdim=True)
133
- # If prediction is incorrect, append the data
134
- if pred != label:
135
- misclassified_data.append((image, label, pred))
136
- mis_count += 1
137
- else:
138
- classified_data.append((image, label, pred))
139
- correct_count += 1
140
-
141
- if ((mis_count>=20) and (correct_count>=20)):
142
- return ((classified_data, misclassified_data))
143
-
144
-
145
- title = "CIFAR10 trained on ResNet18 (Pytorch Lightning) Model with GradCAM"
146
- description = "A simple Gradio interface to infer on ResNet model, get GradCAM results for existing & new Images"
147
-
148
- with gr.Blocks() as gradcam:
149
- classified_data, misclassified_data = get_misclassified_data(model, device, test_loader)
150
-
151
- gr.Markdown("Make Grad-Cam of uploaded image, or existing images.")
152
- with gr.Tab("Upload New Image"):
153
- upload_input = [gr.Image(shape=(32, 32)),
154
- gr.Number(minimum=0, maximum=10, label='n Top Classes', value=3, precision=0),
155
- gr.Slider(0, 1, label='Transparency', value=0.6)]
156
-
157
- upload_output = [gr.Label(label='Top Classes'),
158
- gr.Gallery(label="Image | CAM | Image+CAM",
159
- show_label=True, min_width=80).style(columns=[3],
160
- rows=[1],
161
- object_fit="contain",
162
- height="auto")]
163
- button1 = gr.Button("Perform Inference")
164
- gr.Examples(
165
- examples=examples,
166
- inputs=upload_input,
167
- outputs=upload_output,
168
- fn=upload_image_inference,
169
- cache_examples=True,
170
- )
171
-
172
-
173
- with gr.Tab("View Class Activate Maps"):
174
- with gr.Row():
175
- with gr.Column():
176
- cam_input21 = [gr.Number(minimum=1, maximum=20, precision=0, value=3, label='View Correctly Classified CAM | Num Images'),
177
- gr.Number(minimum=1, maximum=3, precision=0, value=1, label='(-) Target Layer'),
178
- gr.Slider(0, 1, value=0.6, label='Transparency')]
179
-
180
- image_output21 = gr.Gallery(label="Images - Grad-CAM (correct)",
181
- show_label=True, min_width=80)
182
- button21 = gr.Button("View Images")
183
-
184
- with gr.Column():
185
- cam_input22 = [gr.Number(minimum=1, maximum=20, precision=0, value=3, label='View Misclassified CAM | Num Images'),
186
- gr.Number(minimum=1, maximum=3, precision=0, value=1, label='(-) Target Layer'),
187
- gr.Slider(0, 1, value=0.6, label='Transparency')]
188
-
189
- image_output22 = gr.Gallery(label="Images - Grad-CAM (Misclassified)",
190
- show_label=True, min_width=80)
191
- button22 = gr.Button("View Images")
192
-
193
- button1.click(upload_image_inference, inputs=upload_input, outputs=upload_output)
194
- button21.click(class_gr, inputs=cam_input21, outputs=image_output21)
195
- button22.click(misclass_gr, inputs=cam_input22, outputs=image_output22)
196
-
197
-
198
-
199
- gradcam.launch()