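# Gradio prototype for prostate cancer (PCa) grading: a whole-slide .tiff image
# is cut into tiles, the most tissue-dense tiles are assembled into a grid, and
# an EfficientNet-B0 model predicts the ISUP grade of the case.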
import gradio as gr
import numpy as np
import skimage.io
import torch
import torch.nn as nn
from efficientnet_pytorch import model as enet
# Tiling configuration
tile_size = 256   # side length of each square tile (pixels)
image_size = 256  # size each tile occupies in the assembled grid
n_tiles = 36      # number of tiles kept per slide (6 x 6 grid)
batch_size = 8    # unused in this single-image demo
num_workers = 4   # unused in this single-image demo

# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
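# EfficientNet backbone with its original classification head replaced by a
# 5-output linear layer; the five sigmoid outputs are summed and rounded into
# an ISUP grade in predict_label() below.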
class enetv2(nn.Module):
    def __init__(self, backbone, out_dim):
        super(enetv2, self).__init__()
        self.enet = enet.EfficientNet.from_name(backbone)
        # Replace the EfficientNet classification head with a custom linear layer
        self.myfc = nn.Linear(self.enet._fc.in_features, out_dim)
        self.enet._fc = nn.Identity()

    def extract(self, x):
        return self.enet(x)

    def forward(self, x):
        x = self.extract(x)
        x = self.myfc(x)
        return x
def load_models(model_files):
    models = []
    for model_f in model_files:
        backbone = 'efficientnet-b0'
        model = enetv2(backbone, out_dim=5)
        # Load the checkpoint onto CPU first, then move the model to the device
        model.load_state_dict(
            torch.load(model_f, map_location=lambda storage, loc: storage),
            strict=True,
        )
        model.eval()
        model.to(device)
        models.append(model)
        print(f'{model_f} loaded!')
    return models
model_files = [
    'base_implementation'
]
models = load_models(model_files)
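# Pad the slide to a multiple of tile_size, cut it into square tiles, and keep
# the n_tiles tiles with the lowest pixel sums (the darkest tiles, i.e. the
# ones with the most tissue, since the slide background is white). A non-zero
# mode enlarges the padding so the tiling grid is shifted by half a tile,
# giving an alternative layout to average over at inference time.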
def get_tiles(img, mode=0):
    result = []
    h, w, c = img.shape
    pad_h = (tile_size - h % tile_size) % tile_size + ((tile_size * mode) // 2)
    pad_w = (tile_size - w % tile_size) % tile_size + ((tile_size * mode) // 2)
    img2 = np.pad(img, [[pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2], [0, 0]], constant_values=255)
    img3 = img2.reshape(
        img2.shape[0] // tile_size,
        tile_size,
        img2.shape[1] // tile_size,
        tile_size,
        3
    )
    img3 = img3.transpose(0, 2, 1, 3, 4).reshape(-1, tile_size, tile_size, 3)
    n_tiles_with_info = (img3.reshape(img3.shape[0], -1).sum(1) < tile_size ** 2 * 3 * 255).sum()
    if len(img3) < n_tiles:
        img3 = np.pad(img3, [[0, n_tiles - len(img3)], [0, 0], [0, 0], [0, 0]], constant_values=255)
    # Keep the n_tiles darkest tiles (lowest pixel sum = most tissue)
    idxs = np.argsort(img3.reshape(img3.shape[0], -1).sum(-1))[:n_tiles]
    img3 = img3[idxs]
    for i in range(len(img3)):
        result.append({'img': img3[i], 'idx': i})
    return result, n_tiles_with_info >= n_tiles
def getitem(img, tile_mode):
    sub_imgs = False
    tiff_file = img
    image = skimage.io.MultiImage(tiff_file)[0]
    tiles, OK = get_tiles(image, tile_mode)
    idxes = list(range(n_tiles))
    idxes = np.asarray(idxes) + n_tiles if sub_imgs else idxes
    n_row_tiles = int(np.sqrt(n_tiles))
    # Stitch the selected tiles into an n_row_tiles x n_row_tiles grid
    images = np.zeros((image_size * n_row_tiles, image_size * n_row_tiles, 3))
    for h in range(n_row_tiles):
        for w in range(n_row_tiles):
            i = h * n_row_tiles + w
            if len(tiles) > idxes[i]:
                this_img = tiles[idxes[i]]['img']
            else:
                this_img = np.ones((image_size, image_size, 3)).astype(np.uint8) * 255
            this_img = 255 - this_img  # invert so tissue is bright, background dark
            h1 = h * image_size
            w1 = w * image_size
            images[h1:h1 + image_size, w1:w1 + image_size] = this_img
    images = images.astype(np.float32)
    images /= 255
    images = images.transpose(2, 0, 1)  # HWC -> CHW
    return torch.tensor(images)
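# Run the model on two tile layouts (mode 0 and mode 2), average their sigmoid
# outputs, and sum the five averaged values; rounding that sum yields the
# predicted ISUP grade.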
def predict_label(im):
    data1 = getitem(im, 0).unsqueeze(0)  # add a batch dimension: (1, 3, H, W)
    data2 = getitem(im, 2).unsqueeze(0)
    LOGITS = []
    LOGITS2 = []
    with torch.no_grad():
        data1 = data1.to(device)
        logits = models[0](data1)
        LOGITS.append(logits)
        data2 = data2.to(device)
        logits = models[0](data2)
        LOGITS2.append(logits)
    LOGITS = (torch.cat(LOGITS).sigmoid().cpu() + torch.cat(LOGITS2).sigmoid().cpu()) / 2
    PREDS = LOGITS.sum(1).round().numpy()
    return PREDS
def classify_images(im):
    pred = predict_label(im)
    return 'Your submitted case has prostate cancer of ISUP Grade ' + str(int(pred[0]))
img = gr.Image(label="Upload Image", type="filepath")
label = gr.Label()
examples = ["5.tiff", "6.tiff", "7.tiff"]
intf = gr.Interface(title="PCa Detection Prototype",
                    description="This is a prototype of our model's prediction.",
                    fn=classify_images, inputs=img, outputs=label, examples=examples)
intf.launch(inline=False)
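# Usage sketch (assumptions: this script is saved as app.py, the example .tiff
# slides sit next to it, and 'base_implementation' is a valid EfficientNet-B0
# state_dict checkpoint):
#   python app.py
# launch() prints a local URL to open in a browser; pass share=True to
# launch() for a temporary public link.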