File size: 4,572 Bytes
babc795 32ae2ce babc795 6828302 bbb2c43 babc795 d2225db babc795 ec54fd7 babc795 ec54fd7 babc795 ec54fd7 babc795 3c2e1db 9e89537 babc795 60d718c babc795 9e89537 babc795 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import os
import sys
import skimage.io
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import gradio as gr
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from efficientnet_pytorch import model as enet
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm
# --- Tiling / inference configuration ---------------------------------------
tile_size = 256   # side (px) of each square tile cut from the slide by get_tiles
image_size = 256  # side (px) of each mosaic cell in getitem; tiles are copied in unresized
n_tiles = 36      # tiles per slide; getitem lays them out on an int(sqrt(n_tiles)) square grid
# Not referenced by any other code in this file as shown (DataLoader is imported but unused).
batch_size = 8
num_workers = 4

# Lift Pillow's decompression-bomb pixel limit so very large images can be opened through PIL.
Image.MAX_IMAGE_PIXELS = None

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class enetv2(nn.Module):
    """EfficientNet backbone whose stock classifier is replaced by a new linear head.

    Args:
        backbone: EfficientNet variant name passed to ``EfficientNet.from_name``
            (e.g. ``'efficientnet-b0'``).
        out_dim: number of outputs produced by the linear head ``myfc``.
    """

    def __init__(self, backbone, out_dim):
        super().__init__()
        self.enet = enet.EfficientNet.from_name(backbone)
        # Read the classifier's input width before the classifier is discarded.
        feature_dim = self.enet._fc.in_features
        self.myfc = nn.Linear(feature_dim, out_dim)
        # Make the stock classifier a pass-through, so self.enet returns the
        # features that used to feed it.
        self.enet._fc = nn.Identity()

    def extract(self, x):
        """Run the backbone (without its classifier) on ``x``."""
        return self.enet(x)

    def forward(self, x):
        """Return ``out_dim`` raw logits per sample of ``x``."""
        features = self.extract(x)
        return self.myfc(features)
def load_models(model_files, backbone='efficientnet-b0', out_dim=5):
    """Build one ``enetv2`` per checkpoint, load its weights and prepare it for inference.

    Args:
        model_files: iterable of paths to saved ``state_dict`` checkpoints.
        backbone: EfficientNet variant every checkpoint was trained with.
        out_dim: size of the linear head in every checkpoint.

    Returns:
        List of models in ``eval()`` mode on ``device``, in ``model_files`` order.
    """
    models = []
    for model_path in model_files:
        model = enetv2(backbone, out_dim=out_dim)
        # Deserialize onto the CPU so GPU-saved checkpoints also load on CPU-only
        # hosts; the model is moved to ``device`` below.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        model.to(device)
        models.append(model)
        print(f'{model_path} loaded!')
    return models
# Checkpoint(s) loaded once, at import time. Presumably the path is relative to
# the working directory the app is launched from -- confirm for deployment.
model_files = ["base_implementation"]
models = load_models(model_files)
def get_tiles(img, mode=0, *, tile_px=None, num_tiles=None):
    """Cut an image into square tiles and keep the ``num_tiles`` lowest-intensity ones.

    The image is padded with 255 up to a multiple of the tile size and split into
    non-overlapping tiles. Tiles are then ranked by their summed pixel values,
    lowest first. On a white background this presumably selects the tiles with
    the most tissue.

    Args:
        img: ``(H, W, C)`` array. Padding uses 255, so this assumes uint8 data on a
            white background -- TODO confirm for other inputs.
        mode: grid-offset switch. ``(tile_px * mode) // 2`` extra pixels of padding
            are added to each spatial axis and split between both sides, so
            ``mode=2`` shifts the tile grid by roughly half a tile.
        tile_px: tile side length; defaults to the module-level ``tile_size``.
        num_tiles: number of tiles to return; defaults to the module-level ``n_tiles``.

    Returns:
        ``(tiles, ok)``:
        - ``tiles`` is a list of exactly ``num_tiles`` dicts
          ``{'img': (tile_px, tile_px, C) array, 'idx': rank}``.
        - ``ok`` is true when at least ``num_tiles`` tiles contained a pixel below 255.
    """
    if tile_px is None:
        tile_px = tile_size
    if num_tiles is None:
        num_tiles = n_tiles

    h, w, c = img.shape
    extra = (tile_px * mode) // 2
    pad_h = (tile_px - h % tile_px) % tile_px + extra
    pad_w = (tile_px - w % tile_px) % tile_px + extra
    padded = np.pad(
        img,
        [[pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2], [0, 0]],
        constant_values=255,
    )
    # (rows, tile, cols, tile, C) -> (rows, cols, tile, tile, C) -> (N, tile, tile, C)
    tiles = padded.reshape(
        padded.shape[0] // tile_px, tile_px, padded.shape[1] // tile_px, tile_px, c
    )
    tiles = tiles.transpose(0, 2, 1, 3, 4).reshape(-1, tile_px, tile_px, c)

    # An all-255 tile sums to exactly this bound; anything below it has content.
    n_tiles_with_info = (
        tiles.reshape(tiles.shape[0], -1).sum(1) < tile_px ** 2 * c * 255
    ).sum()

    # Top up with all-white tiles so exactly num_tiles are always returned.
    if len(tiles) < num_tiles:
        tiles = np.pad(
            tiles,
            [[0, num_tiles - len(tiles)], [0, 0], [0, 0], [0, 0]],
            constant_values=255,
        )
    idxs = np.argsort(tiles.reshape(tiles.shape[0], -1).sum(-1))[:num_tiles]
    tiles = tiles[idxs]

    result = [{'img': tile, 'idx': i} for i, tile in enumerate(tiles)]
    return result, n_tiles_with_info >= num_tiles
def getitem(img, tile_mode, *, shuffle=False):
    """Read a slide and assemble its selected tiles into one square mosaic tensor.

    Args:
        img: path to the slide image (the Gradio input uses ``type="filepath"``).
        tile_mode: ``mode`` forwarded to ``get_tiles`` (``predict_label`` uses 0 and 2).
        shuffle: if True, place the tiles in random order, as this function
            previously always did. Defaults to False so the same upload always
            yields the same mosaic, and therefore the same prediction.

    Returns:
        float32 tensor of shape ``(1, 3, S, S)``, where
        ``S = image_size * int(sqrt(n_tiles))``.
        - Intensities are inverted (255 becomes 0) and divided by 255.
        - For a uint8 slide the values therefore lie in ``[0, 1]``.
    """
    # Page 0 of the file; for a pyramidal TIFF this is presumably the
    # full-resolution level -- confirm against the slides actually served.
    image = skimage.io.MultiImage(img)[0]
    tiles, _ = get_tiles(image, tile_mode)

    if shuffle:
        order = np.random.choice(list(range(n_tiles)), n_tiles, replace=False)
    else:
        # Fixed order keeps predictions reproducible for the same input.
        order = np.arange(n_tiles)

    n_row_tiles = int(np.sqrt(n_tiles))
    side = image_size * n_row_tiles
    mosaic = np.zeros((side, side, 3), dtype=np.float32)
    blank = np.full((image_size, image_size, 3), 255, dtype=np.uint8)
    for cell in range(n_row_tiles * n_row_tiles):
        row, col = divmod(cell, n_row_tiles)
        src = order[cell]
        # Fall back to an all-white cell if fewer tiles than expected came back.
        tile = tiles[src]['img'] if src < len(tiles) else blank
        y0, x0 = row * image_size, col * image_size
        # NOTE(review): tiles are tile_size px wide but cells are image_size px;
        # this assignment only works while the two constants are equal.
        # Invert intensities so the white background maps to 0.
        mosaic[y0:y0 + image_size, x0:x0 + image_size] = 255 - tile
    mosaic /= 255
    # HWC -> CHW, then prepend a batch dimension of size 1.
    return torch.tensor(mosaic.transpose(2, 0, 1)).unsqueeze(0)
def predict_label(im):
    """Predict a slide's ISUP grade with two-view test-time augmentation.

    Each loaded model scores two mosaics of the slide: tile grid mode 0 and
    mode 2 (a shifted grid). The two sigmoid outputs are averaged per model, then
    averaged across models. The grade is the rounded sum of the resulting
    outputs; with the ``out_dim=5`` heads built by ``load_models`` it lies in 0..5.

    Args:
        im: path to the slide image, forwarded to ``getitem``.

    Returns:
        float32 numpy array of shape ``(1,)`` holding the rounded grade.
    """
    views = [getitem(im, mode).to(device) for mode in (0, 2)]
    per_model = []
    with torch.no_grad():
        for model in models:
            p_plain, p_shifted = (model(view).sigmoid().cpu() for view in views)
            per_model.append((p_plain + p_shifted) / 2)
    # Ensemble mean; identical to the single-model result when one model is loaded.
    probs = torch.stack(per_model).mean(0)
    return probs.sum(1).round().numpy()
def classify_images(im):
    """Gradio callback: turn a slide's predicted ISUP grade into a display message.

    Args:
        im: path to the uploaded slide, forwarded to ``predict_label``.

    Returns:
        str shown in the output ``gr.Label``.
    """
    pred = predict_label(im)
    # predict_label returns a length-1 float array; formatting it directly would
    # render as e.g. "[3.]" instead of "3".
    grade = int(np.asarray(pred).reshape(-1)[0])
    return 'Your submitted case has Prostate cancer of ISUP Grade ' + str(grade)
# Gradio UI: the upload is passed to classify_images as a file path.
# NOTE(review): depending on the Gradio version, gr.Image may re-encode uploads
# before handing over the path; confirm multi-page TIFFs reach skimage intact.
img = gr.Image(label="Upload Image", type="filepath")
label = gr.Label()
examples = ["5.tiff", "6.tiff", "7.tiff"]
intf = gr.Interface(
    title="PCa Detection Prototype",
    description="This is a prototype of our model's prediction",
    fn=classify_images,
    inputs=img,
    outputs=label,
    examples=examples,
)
intf.launch(inline=False)