"""Gradio prototype that grades prostate-cancer slides with an EfficientNet model.

A slide is cut into tiles, the ``n_tiles`` darkest tiles are assembled into a
square mosaic, and the mosaic is scored by the loaded model(s).
"""
import os
import sys

import skimage.io
import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)
import gradio as gr
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from efficientnet_pytorch import model as enet
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm

# Tiling / model-input configuration.
tile_size = 256  # side length (px) of each square tile cut from the slide
image_size = 256  # side length (px) of each mosaic slot; must equal tile_size
                  # because tiles are pasted into the mosaic without resizing
n_tiles = 36  # tiles kept per slide; must be a perfect square (6 x 6 mosaic)
batch_size = 8  # NOTE(review): not used in this script
num_workers = 4  # NOTE(review): not used in this script

# Disable PIL's decompression-bomb size limit so very large slide images open.
Image.MAX_IMAGE_PIXELS = None

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class enetv2(nn.Module):
    """EfficientNet backbone whose classifier is replaced by a new linear head.

    Args:
        backbone: EfficientNet variant name passed to ``EfficientNet.from_name``
            (e.g. ``'efficientnet-b0'``).
        out_dim: Number of outputs produced by the linear head.
    """

    def __init__(self, backbone, out_dim):
        super(enetv2, self).__init__()
        self.enet = enet.EfficientNet.from_name(backbone)
        # Size the new head from the original classifier's input width, then
        # swap that classifier for Identity so the backbone returns features.
        # The attribute names ``enet``/``myfc`` must match checkpoint keys.
        self.myfc = nn.Linear(self.enet._fc.in_features, out_dim)
        self.enet._fc = nn.Identity()

    def extract(self, x):
        """Return backbone features for the batch ``x``."""
        return self.enet(x)

    def forward(self, x):
        """Return ``out_dim`` raw logits per sample in ``x``."""
        x = self.extract(x)
        x = self.myfc(x)
        return x


def load_models(model_files):
    """Create an ``efficientnet-b0`` ``enetv2`` per checkpoint and load its weights.

    Each model is switched to eval mode and moved to ``device``.

    Args:
        model_files: Iterable of paths to ``state_dict`` checkpoint files.

    Returns:
        List of loaded models, in the same order as ``model_files``.
    """
    models = []
    for model_f in model_files:
        model = enetv2('efficientnet-b0', out_dim=5)
        # Map every tensor to CPU storage first so checkpoints saved on a GPU
        # also load on CPU-only hosts.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(model_f, map_location=lambda storage, loc: storage)
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        model.to(device)
        models.append(model)
        print(f'{model_f} loaded!')
    return models


model_files = [
    'base_implementation'
]

models = load_models(model_files)


def get_tiles(img, mode=0):
    """Cut an image into square tiles and keep the ``n_tiles`` darkest ones.

    The image is padded with white (255) so both sides are multiples of
    ``tile_size``. A non-zero ``mode`` adds ``tile_size * mode // 2`` extra
    pixels of padding, which shifts the tile grid.

    Args:
        img: Array of shape ``(H, W, 3)`` (presumably uint8 RGB -- confirm).
        mode: Grid-offset mode; 0 means no extra offset.

    Returns:
        ``(tiles, enough)``.
        ``tiles`` is a list of ``n_tiles`` dicts ``{'img': tile, 'idx': i}``.
        Each tile has shape ``(tile_size, tile_size, 3)``, ordered from lowest
        to highest pixel sum.
        ``enough`` is True when the image had at least ``n_tiles`` tiles that
        are not entirely white.
    """
    h, w, _ = img.shape
    pad_h = (tile_size - h % tile_size) % tile_size + ((tile_size * mode) // 2)
    pad_w = (tile_size - w % tile_size) % tile_size + ((tile_size * mode) // 2)
    img2 = np.pad(
        img,
        [[pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2], [0, 0]],
        constant_values=255,
    )
    # Split into a grid of tiles, then flatten the grid:
    # (rows, tile, cols, tile, 3) -> (rows * cols, tile, tile, 3).
    img3 = img2.reshape(
        img2.shape[0] // tile_size, tile_size,
        img2.shape[1] // tile_size, tile_size,
        3,
    )
    img3 = img3.transpose(0, 2, 1, 3, 4).reshape(-1, tile_size, tile_size, 3)
    # A tile carries information unless every value in it is 255 (pure white).
    n_tiles_with_info = (
        img3.reshape(img3.shape[0], -1).sum(1) < tile_size ** 2 * 3 * 255
    ).sum()
    # Top the grid up with all-white tiles when it has fewer than n_tiles
    # tiles, so the result always has exactly n_tiles entries.
    if len(img3) < n_tiles:
        img3 = np.pad(
            img3,
            [[0, n_tiles - len(img3)], [0, 0], [0, 0], [0, 0]],
            constant_values=255,
        )
    # Keep the n_tiles tiles with the lowest pixel sums (the darkest ones;
    # presumably those with the most tissue).
    idxs = np.argsort(img3.reshape(img3.shape[0], -1).sum(-1))[:n_tiles]
    img3 = img3[idxs]
    result = [{'img': tile, 'idx': i} for i, tile in enumerate(img3)]
    return result, n_tiles_with_info >= n_tiles


def getitem(img, tile_mode):
    """Read a slide, tile it, and assemble the tiles into one mosaic tensor.

    Args:
        img: Path to the slide image. It is read with
            ``skimage.io.MultiImage`` and only the first image is used.
        tile_mode: Forwarded to ``get_tiles`` as ``mode``.

    Returns:
        Float32 tensor of shape ``(1, 3, image_size * k, image_size * k)``,
        where ``k = int(sqrt(n_tiles))``. Pixel values are inverted
        (``255 - x``) and scaled to ``[0, 1]``, so white background becomes 0.
    """
    image = skimage.io.MultiImage(img)[0]
    tiles, _ = get_tiles(image, tile_mode)
    # NOTE(review): tiles are placed in a random order, so repeated calls on
    # the same slide can yield different mosaics and predictions.
    idxes = np.random.choice(list(range(n_tiles)), n_tiles, replace=False)
    n_row_tiles = int(np.sqrt(n_tiles))
    images = np.zeros((image_size * n_row_tiles, image_size * n_row_tiles, 3))
    for h in range(n_row_tiles):
        for w in range(n_row_tiles):
            i = h * n_row_tiles + w
            if len(tiles) > idxes[i]:
                this_img = tiles[idxes[i]]['img']
            else:
                # No tile for this slot: use an all-white tile.
                this_img = np.ones((image_size, image_size, 3)).astype(np.uint8) * 255
            this_img = 255 - this_img
            h1 = h * image_size
            w1 = w * image_size
            images[h1:h1 + image_size, w1:w1 + image_size] = this_img
    images = images.astype(np.float32)
    images /= 255
    # HWC -> CHW for PyTorch, then add a batch dimension.
    images = images.transpose(2, 0, 1)
    return torch.tensor(images).unsqueeze(0)


def predict_label(im):
    """Predict the grade for the slide at path ``im``.

    Two mosaics are built, with tile-grid offset modes 0 and 2. Each is
    scored by ``models[0]`` and the two sigmoid outputs are averaged. The
    grade is the rounded sum of the averaged outputs.

    Returns:
        NumPy float array of shape ``(1,)`` holding the rounded grade.
    """
    data1 = getitem(im, 0)
    data2 = getitem(im, 2)
    with torch.no_grad():
        logits1 = models[0](data1.to(device))
        logits2 = models[0](data2.to(device))
    probs = (logits1.sigmoid().cpu() + logits2.sigmoid().cpu()) / 2
    return probs.sum(1).round().numpy()


def classify_images(im):
    """Gradio callback: return a readable grade message for the slide at ``im``."""
    pred = predict_label(im)
    # predict_label returns a 1-element float array. Report it as an integer
    # grade ("3") rather than the array repr ("[3.]").
    grade = int(pred[0])
    return 'Your submitted case has Prostate cancer of ISUP Grade ' + str(grade)


img = gr.Image(label="Upload Image", type="filepath")
label = gr.Label()
examples = ["5.tiff", "6.tiff", "7.tiff"]
intf = gr.Interface(
    title="PCa Detection ProtoType",
    description="This is Protorype for our model prediction",
    fn=classify_images,
    inputs=img,
    outputs=label,
    examples=examples,
)
intf.launch(inline=False)