divya-22 commited on
Commit
65f3782
·
1 Parent(s): f78deec

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from facenet_pytorch import MTCNN, InceptionResnetV1
5
+ import os
6
+ import numpy as np
7
+ from PIL import Image
8
+ import zipfile
9
+ import cv2
10
+ from pytorch_grad_cam import GradCAM
11
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
12
+ from pytorch_grad_cam.utils.image import show_cam_on_image
13
+
14
+ with zipfile.ZipFile("examples.zip","r") as zip_ref:
15
+ zip_ref.extractall(".")
16
+
17
+ DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
18
+
19
+ mtcnn = MTCNN(
20
+ select_largest=False,
21
+ post_process=False,
22
+ device=DEVICE
23
+ ).to(DEVICE).eval()
24
+
25
+ model = InceptionResnetV1(
26
+ pretrained="vggface2",
27
+ classify=True,
28
+ num_classes=1,
29
+ device=DEVICE
30
+ )
31
+
32
+ checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
33
+ model.load_state_dict(checkpoint['model_state_dict'])
34
+ model.to(DEVICE)
35
+ model.eval()
36
+
37
+ EXAMPLES_FOLDER = 'examples'
38
+ examples_names = os.listdir(EXAMPLES_FOLDER)
39
+ examples = []
40
+ for example_name in examples_names:
41
+ example_path = os.path.join(EXAMPLES_FOLDER, example_name)
42
+ label = example_name.split('_')[0]
43
+ example = {
44
+ 'path': example_path,
45
+ 'label': label
46
+ }
47
+ examples.append(example)
48
+ np.random.shuffle(examples) # shuffle
49
+
50
+ def predict(input_image:Image.Image, true_label:str):
51
+ """Predict the label of the input_image"""
52
+ face = mtcnn(input_image)
53
+ if face is None:
54
+ raise Exception('No face detected')
55
+ face = face.unsqueeze(0) # add the batch dimension
56
+ face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
57
+
58
+ # convert the face into a numpy array to be able to plot it
59
+ prev_face = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()
60
+ prev_face = prev_face.astype('uint8')
61
+
62
+ face = face.to(DEVICE)
63
+ face = face.to(torch.float32)
64
+ face = face / 255.0
65
+ face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()
66
+
67
+ target_layers=[model.block8.branch1[-1]]
68
+ use_cuda = True if torch.cuda.is_available() else False
69
+ cam = GradCAM(model=model, target_layers=target_layers, use_cuda=use_cuda)
70
+ targets = [ClassifierOutputTarget(0)]
71
+
72
+ grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
73
+ grayscale_cam = grayscale_cam[0, :]
74
+ visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)
75
+ face_with_mask = cv2.addWeighted(prev_face, 1, visualization, 0.5, 0)
76
+
77
+ with torch.no_grad():
78
+ output = torch.sigmoid(model(face).squeeze(0))
79
+ prediction = "real" if output.item() < 0.5 else "fake"
80
+
81
+ real_prediction = 1 - output.item()
82
+ fake_prediction = output.item()
83
+
84
+ confidences = {
85
+ 'real': real_prediction,
86
+ 'fake': fake_prediction
87
+ }
88
+ return confidences, true_label, face_with_mask
89
+
90
+ interface = gr.Interface(
91
+ fn=predict,
92
+ inputs=[
93
+ gr.inputs.Image(label="Input Image", type="pil"),
94
+ "text"
95
+ ],
96
+ outputs=[
97
+ gr.outputs.Label(label="Class"),
98
+ "text",
99
+ gr.outputs.Image(label="Face with Explainability")
100
+ ],
101
+ examples=[[examples[i]["path"], examples[i]["label"]] for i in range(10)]
102
+ ).launch()