Upload 3 files
- .gitattributes +1 -0
- app.py +147 -0
- base_implementation +3 -0
- requirements.txt +13 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+base_implementation filter=lfs diff=lfs merge=lfs -text
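The new rule tells Git LFS to track base_implementation (the model checkpoint), so the commit stores only a pointer file; its oid and size appear in the base_implementation diff further down.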
app.py
ADDED
@@ -0,0 +1,160 @@
+import os
+import sys
+import skimage.io
+import numpy as np  # linear algebra
+import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)
+import gradio as gr
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, Dataset
+from efficientnet_pytorch import model as enet
+
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+
+tile_size = 256
+image_size = 256
+n_tiles = 36
+batch_size = 8
+num_workers = 4
+
+# Run on GPU when available; the checkpoint lives next to this script.
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+model_dir = '.'
+
+
+class enetv2(nn.Module):
+    def __init__(self, backbone, out_dim):
+        super(enetv2, self).__init__()
+        self.enet = enet.EfficientNet.from_name(backbone)
+        # Swap the stock classifier head for a linear layer with out_dim outputs.
+        self.myfc = nn.Linear(self.enet._fc.in_features, out_dim)
+        self.enet._fc = nn.Identity()
+
+    def extract(self, x):
+        return self.enet(x)
+
+    def forward(self, x):
+        x = self.extract(x)
+        x = self.myfc(x)
+        return x
+
+
+def load_models(model_files):
+    models = []
+    for model_f in model_files:
+        model_f = os.path.join(model_dir, model_f)
+        backbone = 'efficientnet-b0'
+        model = enetv2(backbone, out_dim=5)
+        model.load_state_dict(torch.load(model_f, map_location=lambda storage, loc: storage), strict=True)
+        model.eval()
+        model.to(device)
+        models.append(model)
+        print(f'{model_f} loaded!')
+    return models
+
+
+model_files = [
+    'base_implementation'
+]
+
+models = load_models(model_files)
+
+
+def get_tiles(img, mode=0):
+    result = []
+    h, w, c = img.shape
+    # Pad so both sides are multiples of tile_size; mode=2 shifts the grid by half a tile.
+    pad_h = (tile_size - h % tile_size) % tile_size + ((tile_size * mode) // 2)
+    pad_w = (tile_size - w % tile_size) % tile_size + ((tile_size * mode) // 2)
+
+    img2 = np.pad(img, [[pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2], [0, 0]], constant_values=255)
+    img3 = img2.reshape(
+        img2.shape[0] // tile_size,
+        tile_size,
+        img2.shape[1] // tile_size,
+        tile_size,
+        3
+    )
+    img3 = img3.transpose(0, 2, 1, 3, 4).reshape(-1, tile_size, tile_size, 3)
+    # A tile carries tissue if it is not entirely white.
+    n_tiles_with_info = (img3.reshape(img3.shape[0], -1).sum(1) < tile_size ** 2 * 3 * 255).sum()
+    if len(img3) < n_tiles:
+        img3 = np.pad(img3, [[0, n_tiles - len(img3)], [0, 0], [0, 0], [0, 0]], constant_values=255)
+    # Keep the n_tiles darkest tiles, i.e. those with the most tissue.
+    idxs = np.argsort(img3.reshape(img3.shape[0], -1).sum(-1))[:n_tiles]
+    img3 = img3[idxs]
+    for i in range(len(img3)):
+        result.append({'img': img3[i], 'idx': i})
+    return result, n_tiles_with_info >= n_tiles
+
+
+def getitem(img, tile_mode):
+    sub_imgs = False
+
+    tiff_file = img
+    image = skimage.io.MultiImage(tiff_file)[0]
+    tiles, OK = get_tiles(image, tile_mode)
+
+    idxes = list(range(n_tiles))
+    idxes = np.asarray(idxes) + n_tiles if sub_imgs else idxes
+
+    # Stitch the selected tiles into one square mosaic, inverted so tissue is bright.
+    n_row_tiles = int(np.sqrt(n_tiles))
+    images = np.zeros((image_size * n_row_tiles, image_size * n_row_tiles, 3))
+    for h in range(n_row_tiles):
+        for w in range(n_row_tiles):
+            i = h * n_row_tiles + w
+
+            if len(tiles) > idxes[i]:
+                this_img = tiles[idxes[i]]['img']
+            else:
+                this_img = np.ones((image_size, image_size, 3)).astype(np.uint8) * 255
+            this_img = 255 - this_img
+            h1 = h * image_size
+            w1 = w * image_size
+            images[h1:h1 + image_size, w1:w1 + image_size] = this_img
+
+    images = images.astype(np.float32)
+    images /= 255
+    images = images.transpose(2, 0, 1)
+
+    return torch.tensor(images)
+
+
+def predict_label(im):
+    data1 = getitem(im, 0)
+    data2 = getitem(im, 2)
+    LOGITS = []
+    LOGITS2 = []
+    with torch.no_grad():
+        # Add a batch dimension before feeding each mosaic to the model.
+        data1 = data1.unsqueeze(0).to(device)
+        logits = models[0](data1)
+        LOGITS.append(logits)
+
+        data2 = data2.unsqueeze(0).to(device)
+        logits = models[0](data2)
+        LOGITS2.append(logits)
+
+    # Average the two tile layouts; summing the ordinal sigmoid outputs gives the ISUP grade.
+    LOGITS = (torch.cat(LOGITS).sigmoid().cpu() + torch.cat(LOGITS2).sigmoid().cpu()) / 2
+    PREDS = LOGITS.sum(1).round().numpy()
+    return PREDS
+
+
+def classify_images(im):
+    pred = predict_label(im)
+    s = 'Your submitted case has prostate cancer of ISUP Grade ' + str(int(pred[0]))
+    return s
+
+
+# The upload is passed through as a file path so skimage.io.MultiImage can read the TIFF.
+img = gr.Image(source="upload", type="filepath")
+label = gr.Label()
+examples = ["5.tiff", "6.tiff", "7.tiff"]
+
+intf = gr.Interface(title="PCa Detection Prototype", description="This is a prototype of our model's prediction", fn=classify_images, inputs=img, outputs=label, examples=examples)
+intf.launch(inline=False)
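A quick way to smoke-test the inference path outside the Gradio UI (a minimal sketch, assuming one of the bundled example slides such as 5.tiff sits next to app.py; predict_label returns a one-element array holding the rounded ISUP grade):

    # hypothetical local check, not part of the committed file
    pred = predict_label("5.tiff")
    print(f"Predicted ISUP grade: {int(pred[0])}")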
base_implementation
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d7a819d3a42ded27c07ad1e2c3490c6ab31a75ce0f9abf07e1ca42dbeab9e956
+size 16297850
requirements.txt
ADDED
@@ -0,0 +1,13 @@
+torch
+torchvision
+pytorch-lightning
+efficientnet-pytorch
+numpy
+pandas
+scikit-learn
+opencv-python
+scikit-image
+albumentations
+Pillow
+matplotlib
+imagecodecs
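To run the Space locally (a sketch, assuming the LFS object has been fetched with `git lfs pull` so base_implementation holds the ~16 MB weights rather than the pointer), install the dependencies with `pip install -r requirements.txt` and start the demo with `python app.py`.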