haritsahm commited on
Commit
cc64157
·
1 Parent(s): 861e32a

Add main deployment script

Browse files
Files changed (1) hide show
  1. main.py +95 -0
main.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import cv2
4
+ import gradio as gr
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+
9
+ from models import phc_models
10
+ from utils import utils
11
+
12
+ BILATERIAL_WEIGHT = 'weights/phresnet18_cbis2views.pt'
13
+ BILATERAL_MODEL = phc_models.PHCResNet18(
14
+ channels=2, n=2, num_classes=1, visualize=True)
15
+ BILATERAL_MODEL.add_top_blocks(num_classes=1)
16
+ BILATERAL_MODEL.load_state_dict(torch.load(
17
+ BILATERIAL_WEIGHT, map_location='cpu'))
18
+ BILATERAL_MODEL = BILATERAL_MODEL.to('cpu')
19
+ BILATERAL_MODEL.eval()
20
+
21
+ OUTPUT_GALLERY = gr.Gallery(
22
+ label='Highlighted Area').style(grid=[2], height='auto')
23
+
24
+
25
+ def predict_bilateral(file: str) -> List:
26
+ """Predict Bilateral Mammography.
27
+
28
+ Parameters
29
+ ----------
30
+ file : TemporaryFileWrapper
31
+ TemporaryFile object for the uploaded file
32
+
33
+ Returns
34
+ -------
35
+ List[List, Dict]
36
+ List of objects that will be used to display the result
37
+ """
38
+ displays_imgs = []
39
+
40
+ image = np.array(Image.open(file.name))/257
41
+ image = np.reshape(image, (2, image.shape[0]//2, image.shape[1]))
42
+
43
+ im_h, im_w = image[0].shape[:2]
44
+
45
+ image_t = torch.from_numpy(image)
46
+ image_t = image_t.unsqueeze(0) # Add batch dimension
47
+
48
+ out, _, out_refiner = BILATERAL_MODEL(image_t)
49
+
50
+ out_refiner = utils.mean_activations(out_refiner).numpy()
51
+
52
+ probability = torch.sigmoid(out).detach().cpu().item()
53
+ label_name = 'Malignant' if probability > 0.5 else 'Normal/Benign'
54
+ lebels_dict = {label_name: probability}
55
+
56
+ refined_view_norm = cv2.normalize(
57
+ out_refiner, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
58
+ refined_view = cv2.applyColorMap(refined_view_norm, cv2.COLORMAP_JET)
59
+ refined_view = cv2.resize(
60
+ refined_view, (im_w, im_h), interpolation=cv2.INTER_LINEAR)
61
+
62
+ image0_colored = cv2.normalize(
63
+ image[0], None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
64
+ image0_colored = cv2.cvtColor(image0_colored, cv2.COLOR_GRAY2RGB)
65
+ image1_colored = cv2.normalize(
66
+ image[1], None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
67
+ image1_colored = cv2.cvtColor(image1_colored, cv2.COLOR_GRAY2RGB)
68
+
69
+ heatmap0_overlay = cv2.addWeighted(
70
+ image0_colored, 1.0, refined_view, 0.5, 0)
71
+ heatmap1_overlay = cv2.addWeighted(
72
+ image1_colored, 1.0, refined_view, 0.5, 0)
73
+
74
+ displays_imgs += [(image0_colored, 'CC'), (image1_colored, 'MLO')]
75
+
76
+ displays_imgs.append((heatmap0_overlay, 'CC Interest Area'))
77
+ displays_imgs.append((heatmap1_overlay, 'MLO Interest Area'))
78
+
79
+ return displays_imgs, lebels_dict
80
+
81
+
82
+ def run():
83
+ """Run Gradio App."""
84
+ demo = gr.Interface(
85
+ fn=predict_bilateral,
86
+ inputs=gr.File(file_count='single', file_types=['.png']),
87
+ outputs=[OUTPUT_GALLERY, gr.Label(label='Cancer Type')]
88
+ )
89
+
90
+ demo.launch(server_name='0.0.0.0', server_port=7860)
91
+ demo.close()
92
+
93
+
94
+ if __name__ == '__main__':
95
+ run()