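# Gradio demo: glioma segmentation on multi-modal brain MRI (T1, T1c, T2, FLAIR)
# using a pretrained MONAI bundle downloaded from the Hugging Face Hub.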
import os
import gradio as gr
import torch
from monai import bundle
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Orientationd,
    NormalizeIntensityd,
    Activationsd,
    AsDiscreted,
    ScaleIntensityd,
)

# Name of the MONAI bundle and the local path (under the torch.hub cache) where it is downloaded
BUNDLE_NAME = 'brats_mri_segmentation_v0.1.0'
BUNDLE_PATH = os.path.join(torch.hub.get_dir(), 'bundle', BUNDLE_NAME)

# Example inputs for the Gradio examples panel: with two input components,
# each example is a [file, z_axis] pair
examples = [['examples/BRATS_485.nii.gz', 100]]

# Download the bundle from the Hugging Face Hub and load its TorchScript module
model, _, _ = bundle.load(
    name=BUNDLE_NAME,
    source='huggingface_hub',
    repo='katielink/brats_mri_segmentation_v0.1.0',
    load_ts_module=True,
)

# Run inference on the GPU if one is available
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Parse the bundle's inference config so its components (e.g. the inferer) can be reused
parser = bundle.load_bundle_config(BUNDLE_PATH, 'inference.json')

# Pre-processing: load the NIfTI file, move the channel dimension first, reorient
# to RAS, and normalize intensities per channel over non-zero voxels
preproc_transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

# Instantiate the inferer exactly as defined in the bundle's inference.json
inferer = parser.get_parsed_content('inferer', lazy=True, eval_expr=True, instantiate=True)

# Post-processing: sigmoid + 0.5 threshold on the prediction, and rescale the
# input image to [0, 1] so it displays correctly in Gradio
post_transforms = Compose(
    [
        Activationsd(keys='pred', sigmoid=True),
        AsDiscreted(keys='pred', threshold=0.5),
        ScaleIntensityd(keys='image', minv=0., maxv=1.),
    ]
)

def predict(input_file, z_axis, model=model, device=device):
    """Segment the uploaded NIfTI volume and return one axial slice of the T1c
    input alongside the first predicted segmentation channel."""
    # Gradio sliders can return floats; cast to an integer slice index
    z_axis = int(z_axis)

    data = {'image': [input_file.name]}
    data = preproc_transforms(data)

    model.to(device)
    model.eval()
    with torch.no_grad():
        inputs = data['image'].to(device)
        data['pred'] = inferer(inputs=inputs[None, ...], network=model)
        data = post_transforms(data)

    input_image = data['image'].numpy()
    pred_image = data['pred'].cpu().detach().numpy()

    # Input channels are T1, T1c, T2, FLAIR; only the T1c slice is displayed
    input_t1_image = input_image[0, :, :, z_axis]
    input_t1c_image = input_image[1, :, :, z_axis]
    input_t2_image = input_image[2, :, :, z_axis]
    input_flair_image = input_image[3, :, :, z_axis]

    # The prediction has three channels; only the first is displayed
    pred_1_image = pred_image[0, 0, :, :, z_axis]
    pred_2_image = pred_image[0, 1, :, :, z_axis]
    pred_3_image = pred_image[0, 2, :, :, z_axis]

    return input_t1c_image, pred_1_image

# Build the Gradio interface: a NIfTI file upload and a slice slider in,
# the selected input slice and its segmentation out
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label='NIfTI file'),
        gr.Slider(0, 200, label='z-axis', value=100),
    ],
    outputs=[
        gr.Image(label='input image'),
        gr.Image(label='segmentation'),
    ],
    title='Segment Gliomas using MONAI',
    examples=examples,
)

iface.launch()