File size: 4,234 Bytes
65a733b
14167ad
65a733b
 
 
 
 
 
 
 
 
 
 
 
14167ad
65a733b
 
 
c7af872
 
3b059ad
64ba3c6
 
3b059ad
 
 
 
c7af872
3b059ad
 
 
 
 
 
 
 
 
 
 
 
 
c7af872
 
64ba3c6
65a733b
 
 
44fed4a
65a733b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64ba3c6
 
539be96
 
65a733b
64ba3c6
 
 
65a733b
64ba3c6
65a733b
c7af872
65a733b
 
 
 
 
 
 
 
d32f7b5
 
c7af872
 
64ba3c6
65a733b
14167ad
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import gradio as gr
import torch
from monai import bundle
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Orientationd,
    NormalizeIntensityd,
    Activationsd,
    AsDiscreted,
    ScaleIntensityd,
)

# Name handed to `bundle.load(name=...)` and used as the last component of the
# local bundle directory below.
# NOTE(review): the name says "spleen_ct" while the repo loaded further down is
# the BraTS MRI bundle -- presumably a copy/paste leftover. Renaming it would
# change the local directory, so confirm before touching the value.
BUNDLE_NAME = "spleen_ct_segmentation_v0.1.0"

# Bundles live in a "bundle" sub-directory of the torch hub directory.
_BUNDLE_ROOT = os.path.join(torch.hub.get_dir(), "bundle")
BUNDLE_PATH = os.path.join(_BUNDLE_ROOT, BUNDLE_NAME)

# Heading of the demo page; passed to `gr.Interface(title=...)`.
title = "Segment Brain Tumors with MONAI!"
# Markdown body of the demo page; passed to `gr.Interface(description=...)`.
# This is user-facing text: it covers what the model segments, the expected
# upload format, a non-diagnostic disclaimer and literature references.
description = """
## Brain Tumor Segmentation 🧠
A pre-trained model for volumetric (3D) segmentation of brain tumor subregions from multimodal MRIs based on BraTS 2018 data.

The model is trained to segment 3 nested subregions of primary brain tumors (gliomas): the "enhancing tumor" (ET), the "tumor core" (TC), the "whole tumor" (WT) based on 4 aligned input MRI scans (T1c, T1, T2, FLAIR).
- The ET is described by areas that show hyper intensity in T1c when compared to T1, but also when compared to "healthy" white matter in T1c.
- The TC describes the bulk of the tumor, which is what is typically resected. The TC entails the ET, as well as the necrotic (fluid-filled) and the non-enhancing (solid) parts of the tumor.
-  The WT describes the complete extent of the disease, as it entails the TC and the peritumoral edema (ED), which is typically depicted by hyper-intense signal in FLAIR.

## To run 🚀

Upload a image file in the format: 4 channel MRI (4 aligned MRIs T1c, T1, T2, FLAIR at 1x1x1 mm)

## Disclaimer ⚠️

This is an example, not to be used for diagnostic purposes.

## References 👀

[1] Myronenko, Andriy. "3D MRI brain tumor segmentation using autoencoder regularization." International MICCAI Brainlesion Workshop. Springer, Cham, 2018. https://arxiv.org/abs/1810.11654.
[2] Menze BH, et al. "The Multimodal Brain Tumor Image Segmentation Benchmark (BRATS)", IEEE Transactions on Medical Imaging 34(10), 1993-2024 (2015) DOI: 10.1109/TMI.2014.2377694
[3] Bakas S, et al. "Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features", Nature Scientific Data, 4:170117 (2017) DOI:10.1038/sdata.2017.117
"""

#examples = 'examples/'

# Load the pre-trained bundle from the Hugging Face Hub at import time.
# `bundle.load` returns a 3-tuple here; only the first element (the network)
# is kept, the other two are discarded.
# NOTE(review): `load_ts_module=True` presumably selects the TorchScript
# export of the network -- confirm against the MONAI bundle documentation.
model, _, _ = bundle.load(
    name=BUNDLE_NAME,
    source="huggingface_hub",
    repo="katielink/brats_mri_segmentation_v0.1.0",
    load_ts_module=True,
)

# Run on the first CUDA GPU when torch can see one, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Parsed form of the bundle's `inference.json`, read from the local bundle
# directory; used below to build the inferer.
parser = bundle.load_bundle_config(BUNDLE_PATH, "inference.json")

# Dictionary-based pre-processing applied to the uploaded file: load it,
# put channels first, reorient to RAS, then normalise intensities per channel
# over non-zero voxels.
preproc_transforms = Compose([
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys="image"),
    Orientationd(keys=["image"], axcodes="RAS"),
    NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
])

# Inferer object instantiated from the "inferer" entry of the bundle config.
inferer = parser.get_parsed_content(
    "inferer",
    lazy=True,
    eval_expr=True,
    instantiate=True,
)

# Post-processing: sigmoid + 0.5 threshold on the prediction, and rescaling of
# the input image to the [0, 1] range so it can be displayed.
post_transforms = Compose([
    Activationsd(keys="pred", sigmoid=True),
    AsDiscreted(keys="pred", threshold=0.5),
    ScaleIntensityd(keys="image", minv=0.0, maxv=1.0),
])

def predict(input_file, z_axis, model=model, device=device):
    """Segment an uploaded volume and return one slice of image and mask.

    The file is run through ``preproc_transforms``, the network is applied
    via ``inferer``, and ``post_transforms`` turns the raw output into a
    thresholded mask and rescales the input image for display.

    Args:
        input_file: Uploaded file object; only its ``name`` attribute (the
            path of the file on disk) is read.
        z_axis: Requested slice index along the last array axis. It comes
            from a slider that allows 0-200 regardless of the uploaded
            volume's size and may arrive as a float, so it is rounded to an
            int and clamped to the valid index range instead of raising
            ``IndexError``.
        model: Network to run; defaults to the module-level bundle model.
        device: Torch device string that model and inputs are moved to.

    Returns:
        A tuple ``(input_slice, pred_slice, slice_index)`` where
        ``input_slice`` is channel 0 of the rescaled input image,
        ``pred_slice`` is channel 0 of the thresholded prediction for the
        single batch item, and ``slice_index`` is the slice actually shown
        (equal to ``z_axis`` whenever that value was a valid index).
    """
    data = preproc_transforms({'image': [input_file.name]})

    model.to(device)
    model.eval()
    with torch.no_grad():
        inputs = data['image'].to(device)
        # The inferer expects a leading batch axis, hence the `[None, ...]`.
        data['pred'] = inferer(inputs=inputs[None, ...], network=model)
        data = post_transforms(data)

    # Same conversion for both tensors so neither depends on where it lives.
    input_image = data['image'].cpu().detach().numpy()
    pred_image = data['pred'].cpu().detach().numpy()

    # Clamp the requested slice to what both arrays can actually provide.
    depth = min(input_image.shape[-1], pred_image.shape[-1])
    slice_index = min(max(int(round(z_axis)), 0), depth - 1)

    # Only the first image channel and the first prediction channel are shown.
    # The original author's notes label image channels 0-3 as T1c, T1, T2,
    # FLAIR and prediction channels 0-2 as TC, ET, WT.
    input_t1c_image = input_image[0, :, :, slice_index]
    pred_tc_image = pred_image[0, 0, :, :, slice_index]

    return input_t1c_image, pred_tc_image, slice_index


# Gradio UI: a file upload plus a slice slider go in; the selected input
# slice, the matching segmentation slice and the slice position come out.
iface = gr.Interface(
    fn=predict,
    title=title,
    description=description,
    inputs=[
        gr.File(label="Nifti file"),
        gr.Slider(0, 200, label="z-axis", value=100),
    ],
    outputs=[
        gr.Image(label="input image"),
        gr.Image(label="segmentation"),
        gr.Slider(0, 200, label="z-axis", value=100),
    ],
    # examples=examples,
)

# Start the web server as soon as the script is executed.
iface.launch()