Spaces:
Build error
Build error
File size: 4,234 Bytes
65a733b 14167ad 65a733b 14167ad 65a733b c7af872 3b059ad 64ba3c6 3b059ad c7af872 3b059ad c7af872 64ba3c6 65a733b 44fed4a 65a733b 64ba3c6 539be96 65a733b 64ba3c6 65a733b 64ba3c6 65a733b c7af872 65a733b d32f7b5 c7af872 64ba3c6 65a733b 14167ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import os
import gradio as gr
import torch
from monai import bundle
from monai.transforms import (
Compose,
LoadImaged,
EnsureChannelFirstd,
Orientationd,
NormalizeIntensityd,
Activationsd,
AsDiscreted,
ScaleIntensityd,
)
# Local cache directory name for the MONAI bundle. bundle.load() skips the
# download when <bundle_dir>/<name>/models/model.ts already exists, so this
# must name the BraTS bundle actually fetched from the Hub -- reusing the
# spleen bundle's name could load a cached, unrelated spleen model instead.
BUNDLE_NAME = 'brats_mri_segmentation_v0.1.0'
# Directory the bundle is expected in; inference.json is read from here.
BUNDLE_PATH = os.path.join(torch.hub.get_dir(), 'bundle', BUNDLE_NAME)
# Page title and Markdown description shown by the Gradio interface.
title = "Segment Brain Tumors with MONAI!"
description = """
## Brain Tumor Segmentation 🧠
A pre-trained model for volumetric (3D) segmentation of brain tumor subregions from multimodal MRIs based on BraTS 2018 data.
The model is trained to segment 3 nested subregions of primary brain tumors (gliomas): the "enhancing tumor" (ET), the "tumor core" (TC), the "whole tumor" (WT) based on 4 aligned input MRI scans (T1c, T1, T2, FLAIR).
- The ET is described by areas that show hyper intensity in T1c when compared to T1, but also when compared to "healthy" white matter in T1c.
- The TC describes the bulk of the tumor, which is what is typically resected. The TC entails the ET, as well as the necrotic (fluid-filled) and the non-enhancing (solid) parts of the tumor.
- The WT describes the complete extent of the disease, as it entails the TC and the peritumoral edema (ED), which is typically depicted by hyper-intense signal in FLAIR.
## To run 🚀
Upload an image file in the format: 4 channel MRI (4 aligned MRIs T1c, T1, T2, FLAIR at 1x1x1 mm)
## Disclaimer ⚠️
This is an example, not to be used for diagnostic purposes.
## References 👀
[1] Myronenko, Andriy. "3D MRI brain tumor segmentation using autoencoder regularization." International MICCAI Brainlesion Workshop. Springer, Cham, 2018. https://arxiv.org/abs/1810.11654.
[2] Menze BH, et al. "The Multimodal Brain Tumor Image Segmentation Benchmark (BRATS)", IEEE Transactions on Medical Imaging 34(10), 1993-2024 (2015) DOI: 10.1109/TMI.2014.2377694
[3] Bakas S, et al. "Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features", Nature Scientific Data, 4:170117 (2017) DOI:10.1038/sdata.2017.117
"""
# Load the TorchScript model from the Hugging Face Hub. bundle_dir is passed
# explicitly so the download lands exactly under BUNDLE_PATH, which is where
# inference.json is read from below, instead of relying on MONAI's default
# bundle directory happening to match torch.hub.get_dir()/bundle.
model, _, _ = bundle.load(
    name=BUNDLE_NAME,
    source='huggingface_hub',
    repo='katielink/brats_mri_segmentation_v0.1.0',
    bundle_dir=os.path.dirname(BUNDLE_PATH),
    load_ts_module=True,
)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Parsed bundle configuration; used below to build the inferer.
parser = bundle.load_bundle_config(BUNDLE_PATH, 'inference.json')

# Preprocessing for the uploaded file: load it, move channels first,
# reorient to RAS, and normalise intensities.
preproc_transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
        # Per-channel normalisation computed over non-zero voxels only.
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)
# The inferer and its settings are defined by the bundle's inference.json.
inferer = parser.get_parsed_content('inferer', lazy=True, eval_expr=True, instantiate=True)
post_transforms = Compose(
    [
        # Element-wise sigmoid on the logits, then binarise at 0.5.
        Activationsd(keys='pred', sigmoid=True),
        AsDiscreted(keys='pred', threshold=0.5),
        # Rescale the input image to [0, 1] for display.
        ScaleIntensityd(keys='image', minv=0., maxv=1.),
    ]
)
def _slice_index(z_axis, depth):
    """Convert a slider value into a valid index along the last (z) axis.

    Slider values may arrive as floats, and the slider range is not tied to
    the uploaded volume's depth, so the value is truncated to an int and
    clamped to ``[0, depth - 1]``.
    """
    return min(max(int(z_axis), 0), depth - 1)


def predict(input_file, z_axis, model=model, device=device):
    """Segment an uploaded multi-channel MRI volume and return one slice.

    Args:
        input_file: The value from ``gr.File`` -- either a file wrapper that
            exposes ``.name`` or a plain path string.
        z_axis: Requested slice index along the volume's last axis. It is
            truncated to an int and clamped to the volume depth.
        model: Network passed to the inferer; defaults to the module-level
            bundle model.
        device: Torch device string the model and inputs are moved to.

    Returns:
        A tuple ``(input_slice, pred_slice, z)``:
        - channel 0 of the intensity-scaled input at slice ``z``;
        - channel 0 of the thresholded prediction at slice ``z``;
        - the slice index actually used, so the output reflects any clamping.

    Raises:
        ValueError: If no file was uploaded.
    """
    if input_file is None:
        raise ValueError("Please upload a NIfTI file before running the model.")
    # Some Gradio versions pass a tempfile wrapper, others the path itself.
    path = getattr(input_file, "name", input_file)

    data = preproc_transforms({"image": [path]})
    model.to(device)
    model.eval()
    with torch.no_grad():
        inputs = data["image"].to(device)
        # Add a batch dimension; the prediction therefore keeps a leading
        # batch axis, which is dropped when slicing below.
        data["pred"] = inferer(inputs=inputs[None, ...], network=model)
    data = post_transforms(data)

    input_image = data["image"].detach().cpu().numpy()
    pred_image = data["pred"].detach().cpu().numpy()
    z = _slice_index(z_axis, input_image.shape[-1])

    # Channel 0 of the input is labelled T1c, following the channel order
    # (T1c, T1, T2, FLAIR) stated in the app description -- TODO confirm
    # against the bundle metadata. Prediction channel 0 is labelled TC;
    # NOTE(review): verify the bundle's output channel order.
    input_t1c_image = input_image[0, :, :, z]
    pred_tc_image = pred_image[0, 0, :, :, z]
    return input_t1c_image, pred_tc_image, z
# Gradio UI: upload a volume and choose a slice along its last axis.
# step=1 keeps slider values integral so they can be used as array indices.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label='Nifti file'),
        gr.Slider(0, 200, step=1, label='z-axis', value=100),
    ],
    outputs=[
        gr.Image(label='input image'),
        gr.Image(label='segmentation'),
        # Displays the slice index returned as predict()'s third value.
        gr.Slider(0, 200, step=1, label='z-axis', value=100),
    ],
    title=title,
    description=description,
)
iface.launch()
|