Upload 3 files
Browse files- .gitattributes +1 -0
- app.py +147 -0
- base_implementation +3 -0
- requirements.txt +13 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
base_implementation filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import skimage.io
|
4 |
+
import numpy as np # linear algebra
|
5 |
+
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch.utils.data import DataLoader, Dataset
|
12 |
+
from efficientnet_pytorch import model as enet
|
13 |
+
|
14 |
+
import matplotlib.pyplot as plt
|
15 |
+
from tqdm import tqdm_notebook as tqdm
|
16 |
+
tile_size = 256
|
17 |
+
image_size = 256
|
18 |
+
n_tiles = 36
|
19 |
+
batch_size = 8
|
20 |
+
num_workers = 4
|
21 |
+
class enetv2(nn.Module):
|
22 |
+
def __init__(self, backbone, out_dim):
|
23 |
+
super(enetv2, self).__init__()
|
24 |
+
self.enet = enet.EfficientNet.from_name(backbone)
|
25 |
+
self.myfc = nn.Linear(self.enet._fc.in_features, out_dim)
|
26 |
+
self.enet._fc = nn.Identity()
|
27 |
+
|
28 |
+
def extract(self, x):
|
29 |
+
return self.enet(x)
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
x = self.extract(x)
|
33 |
+
x = self.myfc(x)
|
34 |
+
return x
|
35 |
+
|
36 |
+
|
37 |
+
def load_models(model_files):
|
38 |
+
models = []
|
39 |
+
for model_f in model_files:
|
40 |
+
model_f = os.path.join(model_dir, model_f)
|
41 |
+
backbone = 'efficientnet-b0'
|
42 |
+
model = enetv2(backbone, out_dim=5)
|
43 |
+
model.load_state_dict(torch.load(model_f, map_location=lambda storage, loc: storage), strict=True)
|
44 |
+
model.eval()
|
45 |
+
model.to(device)
|
46 |
+
models.append(model)
|
47 |
+
print(f'{model_f} loaded!')
|
48 |
+
return models
|
49 |
+
|
50 |
+
|
51 |
+
model_files = [
|
52 |
+
'base_implementation'
|
53 |
+
]
|
54 |
+
|
55 |
+
models = load_models(model_files)
|
56 |
+
|
57 |
+
def get_tiles(img, mode=0):
|
58 |
+
result = []
|
59 |
+
h, w, c = img.shape
|
60 |
+
pad_h = (tile_size - h % tile_size) % tile_size + ((tile_size * mode) // 2)
|
61 |
+
pad_w = (tile_size - w % tile_size) % tile_size + ((tile_size * mode) // 2)
|
62 |
+
|
63 |
+
img2 = np.pad(img,[[pad_h // 2, pad_h - pad_h // 2], [pad_w // 2,pad_w - pad_w//2], [0,0]], constant_values=255)
|
64 |
+
img3 = img2.reshape(
|
65 |
+
img2.shape[0] // tile_size,
|
66 |
+
tile_size,
|
67 |
+
img2.shape[1] // tile_size,
|
68 |
+
tile_size,
|
69 |
+
3
|
70 |
+
)
|
71 |
+
|
72 |
+
img3 = img3.transpose(0,2,1,3,4).reshape(-1, tile_size, tile_size,3)
|
73 |
+
n_tiles_with_info = (img3.reshape(img3.shape[0],-1).sum(1) < tile_size ** 2 * 3 * 255).sum()
|
74 |
+
if len(img) < n_tiles:
|
75 |
+
img3 = np.pad(img3,[[0,N-len(img3)],[0,0],[0,0],[0,0]], constant_values=255)
|
76 |
+
idxs = np.argsort(img3.reshape(img3.shape[0],-1).sum(-1))[:n_tiles]
|
77 |
+
img3 = img3[idxs]
|
78 |
+
for i in range(len(img3)):
|
79 |
+
result.append({'img':img3[i], 'idx':i})
|
80 |
+
return result, n_tiles_with_info >= n_tiles
|
81 |
+
|
82 |
+
def getitem(img,tile_mode):
|
83 |
+
sub_imgs=False
|
84 |
+
|
85 |
+
tiff_file = img
|
86 |
+
image = skimage.io.MultiImage(tiff_file)[0]
|
87 |
+
tiles, OK = get_tiles(image, tile_mode)
|
88 |
+
|
89 |
+
idxes=n_tiles
|
90 |
+
idxes = np.asarray(idxes) + n_tiles if sub_imgs else idxes
|
91 |
+
|
92 |
+
n_row_tiles = int(np.sqrt(n_tiles))
|
93 |
+
images = np.zeros((image_size * n_row_tiles, image_size * n_row_tiles, 3))
|
94 |
+
for h in range(n_row_tiles):
|
95 |
+
for w in range(n_row_tiles):
|
96 |
+
i = h * n_row_tiles + w
|
97 |
+
|
98 |
+
if len(tiles) > idxes[i]:
|
99 |
+
this_img = tiles[idxes[i]]['img']
|
100 |
+
else:
|
101 |
+
this_img = np.ones((image_size, image_size, 3)).astype(np.uint8) * 255
|
102 |
+
this_img = 255 - this_img
|
103 |
+
h1 = h * image_size
|
104 |
+
w1 = w * image_size
|
105 |
+
images[h1:h1+image_size, w1:w1+image_size] = this_img
|
106 |
+
|
107 |
+
# images = 255 - images
|
108 |
+
images = images.astype(np.float32)
|
109 |
+
images /= 255
|
110 |
+
images = images.transpose(2, 0, 1)
|
111 |
+
|
112 |
+
return torch.tensor(images)
|
113 |
+
def predict_label(im):
|
114 |
+
data1=getitem(im,0)
|
115 |
+
data2=getitem(im,2)
|
116 |
+
LOGITS=[]
|
117 |
+
LOGITS2=[]
|
118 |
+
with torch.no_grad():
|
119 |
+
data1 = data1.to(device)
|
120 |
+
logits = models[0](data1)
|
121 |
+
LOGITS.append(logits)
|
122 |
+
|
123 |
+
|
124 |
+
data = data.to(device)
|
125 |
+
logits = models[0](data)
|
126 |
+
LOGITS2.append(logits)
|
127 |
+
|
128 |
+
LOGITS = (torch.cat(LOGITS).sigmoid().cpu() + torch.cat(LOGITS2).sigmoid().cpu()) / 2
|
129 |
+
PREDS = LOGITS.sum(1).round().numpy()
|
130 |
+
return PREDS
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
def classify_images(im):
|
136 |
+
pred,idx,probs=predict_label(im)
|
137 |
+
s='Your submitted case has Prostate cancer of ISUP Grade '+pred
|
138 |
+
return s
|
139 |
+
|
140 |
+
|
141 |
+
|
142 |
+
img=gr.Image(source="upload")
|
143 |
+
label=gr.Label()
|
144 |
+
examples=["5.tiff","6.tiff","7.tiff"]
|
145 |
+
|
146 |
+
intf=gr.Interface(title="PCa Detection ProtoType",description="This is Protorype for our model presiction",fn=classify_images,inputs=img,outputs=label,examples=examples)
|
147 |
+
intf.launch(inline=False)
|
base_implementation
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d7a819d3a42ded27c07ad1e2c3490c6ab31a75ce0f9abf07e1ca42dbeab9e956
|
3 |
+
size 16297850
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
pytorch-lightning
|
4 |
+
efficientnet-pytorch
|
5 |
+
numpy
|
6 |
+
pandas
|
7 |
+
scikit-learn
|
8 |
+
opencv-python
|
9 |
+
scikit-image
|
10 |
+
albumentations
|
11 |
+
Pillow
|
12 |
+
matplotlib
|
13 |
+
imagecodecs
|