PCaBase / app.py
DLBot's picture
Update app.py
9e89537 verified
import os
import sys
import skimage.io
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import gradio as gr
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from efficientnet_pytorch import model as enet
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm
# --- Tiling / inference configuration ---------------------------------
# tile_size: edge length (pixels) of the square tiles cut by get_tiles.
# image_size: edge length (pixels) of each cell in the mosaic built by getitem.
tile_size = image_size = 256
# Number of tiles kept per slide and placed in the mosaic.
n_tiles = 36
# NOTE(review): batch_size and num_workers are not referenced in the
# visible code; presumably left over from a DataLoader-based pipeline.
batch_size, num_workers = 8, 4

# Remove Pillow's pixel-count limit so very large slide images can be opened.
Image.MAX_IMAGE_PIXELS = None

# Define the device: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class enetv2(nn.Module):
    """EfficientNet backbone with a separate linear output layer.

    The backbone's own ``_fc`` layer is swapped for ``nn.Identity`` so that
    calling the backbone yields the input of that layer; ``myfc`` then maps
    it to ``out_dim`` outputs.

    Args:
        backbone: architecture name passed to ``EfficientNet.from_name``.
        out_dim: number of outputs produced by ``myfc``.
    """

    def __init__(self, backbone, out_dim):
        super().__init__()
        self.enet = enet.EfficientNet.from_name(backbone)
        # Read the width of the stock head before it is replaced below.
        feature_dim = self.enet._fc.in_features
        self.myfc = nn.Linear(feature_dim, out_dim)
        self.enet._fc = nn.Identity()

    def extract(self, x):
        """Return the backbone output for ``x`` (head replaced by identity)."""
        return self.enet(x)

    def forward(self, x):
        """Return ``myfc`` applied to the extracted features of ``x``."""
        features = self.extract(x)
        return self.myfc(features)
def load_models(model_files):
    """Build one ``enetv2`` network per checkpoint path and load its weights.

    Each network uses the 'efficientnet-b0' backbone with 5 outputs, is put
    in eval mode and moved to the module-level ``device``.

    Args:
        model_files: iterable of checkpoint paths readable by ``torch.load``.

    Returns:
        List of loaded models, in the same order as ``model_files``.
    """

    def _load_one(path):
        # Load a single checkpoint into a fresh network.
        path = os.path.join(path)
        net = enetv2('efficientnet-b0', out_dim=5)
        # The identity map_location keeps the loaded storages on CPU; the
        # model is moved to the target device afterwards.
        state = torch.load(path, map_location=lambda storage, loc: storage)
        net.load_state_dict(state, strict=True)
        net.eval()
        net.to(device)
        print(f'{path} loaded!')
        return net

    return [_load_one(path) for path in model_files]
# Checkpoint paths loaded once at import time; predict_label uses models[0].
model_files = ['base_implementation']
models = load_models(model_files)
def get_tiles(img, mode=0):
    """Cut an image into square tiles and keep the ``n_tiles`` darkest ones.

    The image is padded with white (255) until both sides are multiples of
    the module-level ``tile_size``, split into ``tile_size`` x ``tile_size``
    tiles, and the tiles are ranked by ascending pixel sum. If fewer than
    ``n_tiles`` tiles exist, all-white tiles are appended so that exactly
    ``n_tiles`` tiles are always returned.

    Args:
        img: array of shape (height, width, 3); assumed to be uint8 with
            white = 255 (TODO confirm against callers).
        mode: adds ``(tile_size * mode) // 2`` pixels of extra padding per
            axis, which shifts the tile grid relative to the image.

    Returns:
        A tuple ``(tiles, enough)`` where ``tiles`` is a list of
        ``{'img': tile_array, 'idx': rank}`` dicts and ``enough`` tells
        whether at least ``n_tiles`` tiles were not completely white.
    """
    h, w, _ = img.shape
    extra = (tile_size * mode) // 2
    pad_h = (tile_size - h % tile_size) % tile_size + extra
    pad_w = (tile_size - w % tile_size) % tile_size + extra

    # Split the padding as evenly as possible between the two sides.
    padded = np.pad(
        img,
        [
            [pad_h // 2, pad_h - pad_h // 2],
            [pad_w // 2, pad_w - pad_w // 2],
            [0, 0],
        ],
        constant_values=255,
    )

    # (rows, tile, cols, tile, 3) -> (rows * cols, tile, tile, 3)
    tiles = padded.reshape(
        padded.shape[0] // tile_size,
        tile_size,
        padded.shape[1] // tile_size,
        tile_size,
        3,
    )
    tiles = tiles.transpose(0, 2, 1, 3, 4).reshape(-1, tile_size, tile_size, 3)

    # A tile whose pixel sum is below the all-white sum has some content.
    white_sum = tile_size ** 2 * 3 * 255
    n_tiles_with_info = (tiles.reshape(tiles.shape[0], -1).sum(1) < white_sum).sum()

    # Fix: compare the number of TILES (not image rows) against n_tiles and
    # pad with n_tiles (the original referenced an undefined name ``N``).
    if len(tiles) < n_tiles:
        tiles = np.pad(
            tiles,
            [[0, n_tiles - len(tiles)], [0, 0], [0, 0], [0, 0]],
            constant_values=255,
        )

    # Lowest sums first, so white padding tiles are ranked last.
    order = np.argsort(tiles.reshape(tiles.shape[0], -1).sum(-1))[:n_tiles]
    tiles = tiles[order]

    result = [{'img': tile, 'idx': i} for i, tile in enumerate(tiles)]
    return result, n_tiles_with_info >= n_tiles
def getitem(img, tile_mode):
    """Load an image file and arrange its tiles into one mosaic tensor.

    The first image returned by ``skimage.io.MultiImage`` is tiled with
    ``get_tiles``; the tiles are colour-inverted (``255 - tile``) and placed
    in random order on a square grid of ``int(sqrt(n_tiles))`` cells per
    side. Grid cells without a matching tile are filled from a white tile.

    Args:
        img: path of the image file to read.
        tile_mode: forwarded to ``get_tiles`` as its ``mode`` argument.

    Returns:
        float32 tensor of shape (1, 3, side, side) scaled by 1/255, where
        ``side = image_size * int(sqrt(n_tiles))``.
    """
    slide = skimage.io.MultiImage(img)[0]
    tiles, _ = get_tiles(slide, tile_mode)

    # NOTE(review): tile placement is random on every call, so repeated
    # predictions for the same file may differ - confirm this is intended.
    order = np.random.choice(list(range(n_tiles)), n_tiles, replace=False)

    grid = int(np.sqrt(n_tiles))
    mosaic = np.zeros((image_size * grid, image_size * grid, 3))
    # Fallback for grid cells whose tile index is out of range.
    blank = np.ones((image_size, image_size, 3)).astype(np.uint8) * 255

    for cell in range(grid * grid):
        row, col = divmod(cell, grid)
        tile_idx = order[cell]
        tile = tiles[tile_idx]['img'] if tile_idx < len(tiles) else blank
        top = row * image_size
        left = col * image_size
        mosaic[top:top + image_size, left:left + image_size] = 255 - tile

    # Scale to [0, 1] and move channels first: (H, W, 3) -> (3, H, W).
    mosaic = (mosaic.astype(np.float32) / 255).transpose(2, 0, 1)
    # Add a batch dimension
    return torch.tensor(mosaic).unsqueeze(0)
def predict_label(im):
    """Predict a grade for one image by averaging two tilings.

    The image is converted to a mosaic twice (tile modes 0 and 2), both are
    run through ``models[0]``, and the sigmoid outputs are averaged. The
    averaged outputs are summed over dimension 1 and rounded.

    Args:
        im: image path forwarded to ``getitem``.

    Returns:
        NumPy array with one rounded value per batch row (``getitem``
        produces a batch of one).
    """
    batch_plain = getitem(im, 0)
    batch_shifted = getitem(im, 2)

    with torch.no_grad():
        probs_plain = models[0](batch_plain.to(device)).sigmoid().cpu()
        probs_shifted = models[0](batch_shifted.to(device)).sigmoid().cpu()

    mean_probs = (probs_plain + probs_shifted) / 2
    return mean_probs.sum(1).round().numpy()
def classify_images(im):
    """Return a human-readable grade message for an uploaded image.

    Args:
        im: image reference forwarded unchanged to ``predict_label``.

    Returns:
        The message string ending in the predicted grade as an integer.
    """
    pred = predict_label(im)
    # predict_label returns a one-element float array; str() on it would
    # render as "[3.]". Extract the scalar so the message reads "Grade 3".
    grade = int(np.asarray(pred).reshape(-1)[0])
    return 'Your submitted case has Prostate cancer of ISUP Grade ' + str(grade)
# Gradio front end: one image upload (handed to the callback as a file
# path) in, one label out.
img = gr.Image(label="Upload Image", type="filepath")
label = gr.Label()
examples = ["5.tiff", "6.tiff", "7.tiff"]

# NOTE(review): "Protorype" in the description looks like a typo for
# "Prototype"; left unchanged here because it is user-facing text.
intf = gr.Interface(
    fn=classify_images,
    inputs=img,
    outputs=label,
    examples=examples,
    title="PCa Detection ProtoType",
    description="This is Protorype for our model prediction",
)
intf.launch(inline=False)