|
import os |
|
import sys |
|
import skimage.io |
|
import numpy as np |
|
import pandas as pd |
|
import gradio as gr |
|
from PIL import Image |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.utils.data import DataLoader, Dataset |
|
from efficientnet_pytorch import model as enet |
|
|
|
import matplotlib.pyplot as plt |
|
from tqdm import tqdm_notebook as tqdm |
|
# --- Tiling / inference configuration ---------------------------------------
# Edge length in pixels of one square tile cut from the slide, and of one cell
# in the assembled mosaic. Tiles are written into mosaic cells unchanged, so
# the two values are kept equal.
tile_size = image_size = 256
# Number of tiles kept per slide and placed into the mosaic.
n_tiles = 36
# Loader settings; not referenced by the code visible in this section.
batch_size, num_workers = 8, 4

# Remove Pillow's pixel-count ceiling so very large images are not rejected.
Image.MAX_IMAGE_PIXELS = None

# Use CUDA when PyTorch reports it is available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
class enetv2(nn.Module):
    """EfficientNet backbone followed by a separate linear head.

    The backbone is created with ``EfficientNet.from_name(backbone)``. Its own
    ``_fc`` layer is replaced by ``nn.Identity`` so the backbone returns
    features, and ``myfc`` maps those features to ``out_dim`` outputs.

    Args:
        backbone: Architecture name passed to ``EfficientNet.from_name``.
        out_dim: Number of outputs produced by the linear head.
    """

    def __init__(self, backbone, out_dim):
        super().__init__()
        net = enet.EfficientNet.from_name(backbone)
        # Read the feature width before the original classifier is dropped.
        feature_dim = net._fc.in_features
        net._fc = nn.Identity()
        # Register the backbone first, then the head, as attribute order
        # defines sub-module registration order.
        self.enet = net
        self.myfc = nn.Linear(feature_dim, out_dim)

    def extract(self, x):
        """Return the backbone's output for ``x``."""
        return self.enet(x)

    def forward(self, x):
        """Return the linear head's output for ``x``."""
        return self.myfc(self.extract(x))
|
|
|
|
|
def load_models(model_files):
    """Build one ``enetv2`` network per checkpoint path and load its weights.

    Each network uses the ``'efficientnet-b0'`` backbone with 5 outputs, is
    switched to eval mode, and is moved to the module-level ``device``.

    Args:
        model_files: Iterable of checkpoint paths readable by ``torch.load``.

    Returns:
        List of loaded networks, in the same order as ``model_files``.

    Raises:
        Whatever ``torch.load`` or ``load_state_dict(strict=True)`` raise for
        a missing file or a state dict whose keys do not match the network.
    """
    backbone = 'efficientnet-b0'
    loaded = []
    for path in model_files:
        path = os.path.join(path)
        net = enetv2(backbone, out_dim=5)
        # The map_location callable returns each storage unchanged, so the
        # weights are not remapped during loading; the move to ``device``
        # happens explicitly below.
        state = torch.load(path, map_location=lambda storage, loc: storage)
        net.load_state_dict(state, strict=True)
        net.eval()
        net.to(device)
        loaded.append(net)
        print(f'{path} loaded!')
    return loaded
|
|
|
|
|
# Checkpoint paths handed to load_models(); resolved relative to the working
# directory because no directory component is given.
model_files = ['base_implementation']

# Networks are loaded once, when this module is executed.
models = load_models(model_files)
|
|
|
def get_tiles(img, mode=0, *, tile_px=None, tile_count=None):
    """Cut an image into square tiles and keep the ones with the lowest sum.

    The image is padded with 255 so both spatial dimensions become multiples
    of the tile edge, then split into non-overlapping tiles. Tiles are ranked
    by the sum of their values (lowest first) and the first ``tile_count`` are
    returned. If the image yields fewer tiles than that, all-255 tiles are
    appended so the result always has exactly ``tile_count`` entries.

    Args:
        img: Array of shape ``(height, width, channels)``.
        mode: Adds ``(tile_px * mode) // 2`` pixels of padding per axis, which
            shifts the tile grid. The padded size must stay a multiple of the
            tile edge, which holds for even values (this file uses 0 and 2).
        tile_px: Tile edge in pixels; defaults to the module-level
            ``tile_size``.
        tile_count: Number of tiles to return; defaults to the module-level
            ``n_tiles``.

    Returns:
        ``(tiles, enough)`` where ``tiles`` is a list of
        ``{'img': ndarray, 'idx': int}`` dicts in ranked order, and ``enough``
        is true when at least ``tile_count`` tiles (counted before any filler
        tiles are added) were not entirely 255.

    Raises:
        ValueError: If ``img`` is not 3-dimensional, or the padded image
            cannot be reshaped into whole tiles (odd ``mode``).
    """
    tile_px = tile_size if tile_px is None else tile_px
    tile_count = n_tiles if tile_count is None else tile_count

    h, w, c = img.shape
    shift = (tile_px * mode) // 2
    pad_h = (tile_px - h % tile_px) % tile_px + shift
    pad_w = (tile_px - w % tile_px) % tile_px + shift

    # Split the padding as evenly as possible between the two sides.
    padded = np.pad(
        img,
        [
            [pad_h // 2, pad_h - pad_h // 2],
            [pad_w // 2, pad_w - pad_w // 2],
            [0, 0],
        ],
        constant_values=255,
    )

    grid_rows = padded.shape[0] // tile_px
    grid_cols = padded.shape[1] // tile_px
    tiles = (
        padded.reshape(grid_rows, tile_px, grid_cols, tile_px, c)
        .transpose(0, 2, 1, 3, 4)
        .reshape(-1, tile_px, tile_px, c)
    )

    # A tile "has info" when at least one value is below 255.
    tile_sums = tiles.reshape(tiles.shape[0], -1).sum(1)
    n_tiles_with_info = (tile_sums < tile_px ** 2 * c * 255).sum()

    # Fix: the guard used to test len(img) (the image height) and padded by an
    # undefined name, so short tile lists were either returned unpadded or
    # triggered a NameError. Compare and pad by the tile count instead.
    missing = tile_count - len(tiles)
    if missing > 0:
        tiles = np.pad(
            tiles,
            [[0, missing], [0, 0], [0, 0], [0, 0]],
            constant_values=255,
        )

    order = np.argsort(tiles.reshape(tiles.shape[0], -1).sum(-1))[:tile_count]
    tiles = tiles[order]

    result = [{'img': tile, 'idx': i} for i, tile in enumerate(tiles)]
    return result, n_tiles_with_info >= tile_count
|
|
|
def getitem(img, tile_mode):
    """Read an image file and lay its tiles out as one square mosaic tensor.

    Args:
        img: Path (or other input accepted by ``skimage.io.MultiImage``) of
            the image file; the first image of the collection is used.
        tile_mode: Forwarded to ``get_tiles`` as its ``mode`` argument.

    Returns:
        ``torch`` float32 tensor of shape ``(1, 3, S, S)`` with
        ``S = image_size * int(sqrt(n_tiles))``. Values are ``255 - pixel``
        scaled to the range 0..1, so cells without a tile come out as 0.
    """
    slide = skimage.io.MultiImage(img)[0]
    tiles, _ = get_tiles(slide, tile_mode)

    # NOTE(review): the slot-to-tile assignment is a fresh random permutation
    # on every call, so the same file can produce differently arranged
    # mosaics between calls. Confirm this is intended at inference time.
    order = np.random.choice(np.arange(n_tiles), n_tiles, replace=False)

    side = int(np.sqrt(n_tiles))
    # All-255 stand-in for slots whose index has no matching tile.
    blank = np.ones((image_size, image_size, 3)).astype(np.uint8) * 255
    canvas = np.zeros((image_size * side, image_size * side, 3))

    for slot in range(side * side):
        row, col = divmod(slot, side)
        pick = order[slot]
        patch = tiles[pick]['img'] if pick < len(tiles) else blank
        top = row * image_size
        left = col * image_size
        # Invert intensities before writing the patch into its cell.
        canvas[top:top + image_size, left:left + image_size] = 255 - patch

    # Scale to 0..1 and reorder from height/width/channel to channel-first.
    mosaic = (canvas.astype(np.float32) / 255).transpose(2, 0, 1)
    return torch.tensor(mosaic).unsqueeze(0)
|
|
|
|
|
def predict_label(im):
    """Score an image with the first loaded model using two tile layouts.

    Two mosaics are built from the same file (tile modes 0 and 2). Each is run
    through ``models[0]`` without gradient tracking, and the sigmoid outputs
    of both passes are averaged.

    Args:
        im: Image file reference forwarded to ``getitem``.

    Returns:
        NumPy array holding, for each row of the model output, the averaged
        sigmoid values summed along dimension 1 and rounded.
    """
    net = models[0]
    # Build both mosaics up front, in the same order as before.
    views = (getitem(im, 0), getitem(im, 2))

    with torch.no_grad():
        probs = [net(view.to(device)).sigmoid().cpu() for view in views]

    mean_probs = (probs[0] + probs[1]) / 2
    return mean_probs.sum(1).round().numpy()
|
|
|
|
|
|
|
def classify_images(im):
    """Return the user-facing result sentence for an uploaded image.

    Args:
        im: Image file reference forwarded to ``predict_label``.

    Returns:
        A sentence ending in the predicted grade as a plain integer.
    """
    pred = predict_label(im)
    # predict_label returns a NumPy array, and str() of that array renders
    # with brackets and a trailing dot (e.g. "[3.]"). Take the first value
    # and show it as an integer instead.
    grade = int(np.asarray(pred).reshape(-1)[0])
    return 'Your submitted case has Prostate cancer of ISUP Grade ' + str(grade)
|
|
|
|
|
|
|
# --- Gradio front end --------------------------------------------------------
# Input component configured with type="filepath"; output is a Label component.
img = gr.Image(label="Upload Image", type="filepath")
label = gr.Label()
# Example inputs offered in the UI; given as bare file names.
examples = ["5.tiff", "6.tiff", "7.tiff"]

# NOTE(review): "Protorype" in the description looks like a typo for
# "Prototype"; the text is left untouched here because it is user-facing.
intf = gr.Interface(
    fn=classify_images,
    inputs=img,
    outputs=label,
    examples=examples,
    title="PCa Detection ProtoType",
    description="This is Protorype for our model prediction",
)

# Starts the app as soon as this module is executed.
intf.launch(inline=False)