File size: 4,572 Bytes
babc795
 
 
 
 
 
32ae2ce
babc795
 
 
 
 
 
 
 
 
 
 
 
 
6828302
 
bbb2c43
 
 
 
babc795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2225db
babc795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec54fd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
babc795
ec54fd7
 
 
 
 
babc795
 
 
 
 
ec54fd7
 
 
babc795
 
 
 
 
 
 
 
3c2e1db
9e89537
babc795
 
 
 
60d718c
babc795
 
 
9e89537
babc795
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
import sys
import skimage.io
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import gradio as gr
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from efficientnet_pytorch import model as enet

import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm
# Side length, in pixels, of the square tiles that get_tiles cuts out of a slide.
tile_size = 256
# Side length, in pixels, of each tile slot in the mosaic assembled by getitem.
image_size = 256
# Number of tiles kept per slide; getitem lays them out on an
# int(sqrt(n_tiles)) x int(sqrt(n_tiles)) grid.
n_tiles = 36
# NOTE(review): batch_size and num_workers are not referenced by the code
# visible in this file -- presumably leftovers from a DataLoader setup.
batch_size = 8
num_workers = 4
# Remove Pillow's pixel-count limit so very large slide images are not
# rejected (Pillow treats None as "no limit" -- confirm against Pillow docs).
Image.MAX_IMAGE_PIXELS = None



# Define the device: use CUDA when torch reports it available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class enetv2(nn.Module):
    """EfficientNet backbone with its classifier swapped for a custom head.

    The backbone's own ``_fc`` layer is replaced by an identity so that the
    backbone emits its feature vector, and ``myfc`` maps those features to
    ``out_dim`` outputs.

    Args:
        backbone: Architecture name handed to ``EfficientNet.from_name``.
        out_dim: Number of outputs produced by the linear head.
    """

    def __init__(self, backbone, out_dim):
        super().__init__()
        self.enet = enet.EfficientNet.from_name(backbone)
        # Read the feature width before the original classifier is discarded.
        feature_dim = self.enet._fc.in_features
        self.myfc = nn.Linear(feature_dim, out_dim)
        self.enet._fc = nn.Identity()

    def extract(self, x):
        """Return the backbone output for ``x`` (classifier is an identity)."""
        return self.enet(x)

    def forward(self, x):
        """Run the backbone, then the linear head, and return the raw outputs."""
        return self.myfc(self.extract(x))
    
    
def load_models(model_files):
    """Build one ``enetv2`` network per checkpoint and load its weights.

    Every network uses the 'efficientnet-b0' backbone with 5 outputs, is
    switched to eval mode and moved to the module-level ``device``.

    Args:
        model_files: Iterable of checkpoint paths readable by ``torch.load``.

    Returns:
        A list of loaded models, in the same order as ``model_files``.
    """
    def _keep_storage(storage, loc):
        # map_location callback: hand back the deserialized storage as-is.
        return storage

    loaded = []
    for checkpoint in model_files:
        checkpoint = os.path.join(checkpoint)
        net = enetv2('efficientnet-b0', out_dim=5)
        weights = torch.load(checkpoint, map_location=_keep_storage)
        # strict=True: every key in the checkpoint must match the network.
        net.load_state_dict(weights, strict=True)
        net.eval()
        net.to(device)
        loaded.append(net)
        print(f'{checkpoint} loaded!')
    return loaded


# Checkpoint paths handed to load_models; each entry is read with torch.load.
model_files = [
    'base_implementation'
]

# Loaded once at import time; predict_label runs inference with models[0].
models = load_models(model_files)

def get_tiles(img, mode=0, tile_px=None, max_tiles=None):
    """Cut a 3-channel image into square tiles and keep the darkest ones.

    The image is padded with white (255) so both sides are multiples of the
    tile size, split into non-overlapping tiles, and the tiles with the
    smallest pixel sum (i.e. the least white) are kept.

    Args:
        img: Array of shape (height, width, 3).
        mode: Adds ``(tile size * mode) // 2`` extra white padding per axis,
            which shifts the tile grid. Values for which that amount is not
            a multiple of the tile size make the reshape fail; the callers
            in this file pass 0 and 2.
        tile_px: Tile side length; defaults to the module-level ``tile_size``.
        max_tiles: Number of tiles to return; defaults to the module-level
            ``n_tiles``.

    Returns:
        A tuple ``(tiles, enough)``. ``tiles`` is a list of exactly
        ``max_tiles`` dicts ``{'img': tile, 'idx': position}`` ordered from
        darkest to lightest. ``enough`` is truthy when at least ``max_tiles``
        tiles contained any non-white pixel.
    """
    size = tile_size if tile_px is None else tile_px
    count = n_tiles if max_tiles is None else max_tiles

    h, w, _ = img.shape
    extra = (size * mode) // 2
    pad_h = (size - h % size) % size + extra
    pad_w = (size - w % size) % size + extra

    padded = np.pad(
        img,
        [
            [pad_h // 2, pad_h - pad_h // 2],
            [pad_w // 2, pad_w - pad_w // 2],
            [0, 0],
        ],
        constant_values=255,
    )

    # (rows, size, cols, size, 3) -> (rows * cols, size, size, 3)
    tiles = padded.reshape(
        padded.shape[0] // size,
        size,
        padded.shape[1] // size,
        size,
        3,
    )
    tiles = tiles.transpose(0, 2, 1, 3, 4).reshape(-1, size, size, 3)

    # A tile whose pixel sum is below the all-white sum carries some content.
    all_white_sum = size ** 2 * 3 * 255
    n_tiles_with_info = (tiles.reshape(tiles.shape[0], -1).sum(1) < all_white_sum).sum()

    # Top up with white tiles when the image yields fewer tiles than needed.
    # (The check previously compared the image height instead of the tile
    # count and referenced an undefined name, so it either never ran or
    # raised NameError.)
    if len(tiles) < count:
        tiles = np.pad(
            tiles,
            [[0, count - len(tiles)], [0, 0], [0, 0], [0, 0]],
            constant_values=255,
        )

    darkest_first = np.argsort(tiles.reshape(tiles.shape[0], -1).sum(-1))[:count]
    tiles = tiles[darkest_first]

    result = [{'img': tile, 'idx': i} for i, tile in enumerate(tiles)]
    return result, n_tiles_with_info >= count

def getitem(img, tile_mode):
    """Read an image file and assemble the tile mosaic fed to the model.

    Args:
        img: Image source passed to ``skimage.io.MultiImage``; only its
            first element is used.
        tile_mode: Forwarded to ``get_tiles`` as its ``mode`` argument.

    Returns:
        A float32 tensor of shape (1, 3, side, side) where
        ``side = image_size * int(sqrt(n_tiles))``. Pixel values are
        inverted (255 - value) and divided by 255.
    """
    first_level = skimage.io.MultiImage(img)[0]
    tiles, _ = get_tiles(first_level, tile_mode)

    # NOTE: tiles are placed in a fresh random order on every call, so the
    # mosaic for the same image differs between calls.
    order = np.random.choice(list(range(n_tiles)), n_tiles, replace=False)

    grid = int(np.sqrt(n_tiles))
    canvas = np.zeros((image_size * grid, image_size * grid, 3))
    for slot in range(grid * grid):
        row, col = divmod(slot, grid)
        tile_idx = order[slot]

        if tile_idx < len(tiles):
            patch = tiles[tile_idx]['img']
        else:
            # No tile for this slot: fill it with a white square.
            patch = np.ones((image_size, image_size, 3)).astype(np.uint8) * 255

        top = row * image_size
        left = col * image_size
        # Invert so that white background becomes 0.
        canvas[top:top + image_size, left:left + image_size] = 255 - patch

    canvas = canvas.astype(np.float32)
    canvas /= 255
    # Height-width-channel -> channel-height-width, then add a batch dimension.
    return torch.tensor(canvas.transpose(2, 0, 1)).unsqueeze(0)


def predict_label(im):
    """Predict a grade for one image by averaging two tilings.

    The image is turned into two mosaics (tile modes 0 and 2), each is run
    through ``models[0]``, and the sigmoid outputs are averaged. The
    averaged outputs are summed across the output dimension and rounded.

    Args:
        im: Image source forwarded to ``getitem``.

    Returns:
        A NumPy array with one rounded value per batch item (the batch
        built here holds a single item).
    """
    mosaics = (getitem(im, 0), getitem(im, 2))

    with torch.no_grad():
        net = models[0]
        probs = [net(mosaic.to(device)).sigmoid().cpu() for mosaic in mosaics]

    averaged = (probs[0] + probs[1]) / 2
    return averaged.sum(1).round().numpy()



def classify_images(im):
    """Return the user-facing result sentence for an uploaded image.

    Args:
        im: Image source forwarded to ``predict_label``.

    Returns:
        A string reporting the predicted ISUP grade as a plain integer.
    """
    pred = predict_label(im)
    # predict_label returns a one-element float array; str() of that array
    # renders as "[3.]", so take the scalar and show it as an integer.
    grade = int(np.asarray(pred).reshape(-1)[0])
    return 'Your submitted case has Prostate cancer of ISUP Grade ' + str(grade)



# Input widget; type="filepath" asks Gradio to hand the callback a path to
# the uploaded file (classify_images forwards it to skimage via getitem).
img=gr.Image(label="Upload Image", type="filepath")
# Output widget that shows the sentence returned by classify_images.
label=gr.Label()
# Example inputs offered in the UI.
# NOTE(review): presumably these files sit next to this script -- confirm.
examples=["5.tiff","6.tiff","7.tiff"]

# Wire the widgets to classify_images and start the app at import time.
intf=gr.Interface(title="PCa Detection ProtoType",description="This is Protorype for our model prediction",fn=classify_images,inputs=img,outputs=label,examples=examples)
intf.launch(inline=False)