# Source: Hugging Face file view of app.py (3.04 kB).
# Last commit: "Update app.py" by ayyuce, revision 2623207 (verified).
import gradio as gr
import subprocess
import os
import shutil
import uuid
import zipfile
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
def _save_upload(uploaded_file, destination):
    """Write the uploaded payload to *destination*.

    Accepts either a filesystem path (``str``) or an object exposing
    ``read()``. Returns ``True`` on success and ``False`` when the input is
    of neither kind.
    """
    if isinstance(uploaded_file, str):
        shutil.copy(uploaded_file, destination)
        return True
    if hasattr(uploaded_file, "read"):
        with open(destination, "wb") as f:
            f.write(uploaded_file.read())
        return True
    return False


def _zip_directory(folder, zip_path):
    """Archive every file under *folder* into *zip_path* with folder-relative names."""
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
        for root, _dirs, files in os.walk(folder):
            for name in files:
                file_path = os.path.join(root, name)
                zipf.write(file_path, os.path.relpath(file_path, folder))


def _render_preview(output_folder, image_filename):
    """Save a PNG of the middle slice (third axis) of one segmentation file.

    Returns *image_filename* on success, or ``None`` when the folder holds no
    ``.nii.gz`` file or the preview could not be rendered.
    """
    if not os.path.isdir(output_folder):
        return None
    # Sort so the previewed file is deterministic; os.listdir order is arbitrary.
    seg_names = sorted(
        name for name in os.listdir(output_folder) if name.endswith(".nii.gz")
    )
    if not seg_names:
        return None

    try:
        seg_img = nib.load(os.path.join(output_folder, seg_names[0]))
        seg_data = seg_img.get_fdata()
        seg_slice = seg_data[:, :, seg_data.shape[2] // 2]
        plt.figure(figsize=(6, 6))
        try:
            plt.imshow(seg_slice.T, cmap="gray", origin="lower")
            plt.axis("off")
            plt.savefig(image_filename, bbox_inches="tight")
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close()
    except Exception as e:
        # The preview is best-effort: a failure here must not discard the zip.
        print(f"Error creating preview: {e}")
        return None
    return image_filename


def run_segmentation(uploaded_file, modality):
    """Run the TotalSegmentator CLI on an uploaded image and package the output.

    Args:
        uploaded_file: Path to the uploaded file, or an object with ``read()``.
        modality: ``"MR"`` adds ``--task total_mr`` to the command; any other
            value runs the command without a ``--task`` flag.

    Returns:
        A ``(zip_filename, image_filename)`` tuple. ``image_filename`` is
        ``None`` when no preview could be produced. On invalid input or a
        failed segmentation run the tuple is ``(error_message, None)``.

    The temporary input copy and the raw output folder are removed on every
    exit path; the zip archive and preview PNG are left in the working
    directory for the caller.
    """
    # NOTE(review): on failure the first element is a message, not a file
    # path. If the caller feeds it to a file-download widget it may not be
    # displayed as text -- confirm, and consider raising gr.Error instead.
    job_id = str(uuid.uuid4())
    input_filename = f"input_{job_id}.nii.gz"
    output_folder = f"segmentations_{job_id}"
    zip_filename = f"segmentations_{job_id}.zip"
    image_filename = f"segmentation_preview_{job_id}.png"

    try:
        if not _save_upload(uploaded_file, input_filename):
            return "Invalid file input", None

        command = ["TotalSegmentator", "-i", input_filename, "-o", output_folder]
        if modality == "MR":
            command.extend(["--task", "total_mr"])

        try:
            subprocess.run(command, check=True)
        except (subprocess.CalledProcessError, OSError) as e:
            # OSError covers a missing or non-executable TotalSegmentator binary.
            return f"Error during segmentation: {e}", None

        _zip_directory(output_folder, zip_filename)
        return zip_filename, _render_preview(output_folder, image_filename)
    finally:
        # Clean up intermediates on success and on every failure path.
        if os.path.exists(input_filename):
            os.remove(input_filename)
        shutil.rmtree(output_folder, ignore_errors=True)
# Build the single-page UI: a title and description, an input row (file upload
# plus modality selector), an output row (zip download plus preview image),
# and a button that feeds the inputs to run_segmentation.
with gr.Blocks() as demo:
    gr.Markdown("# TotalSegmentator Gradio App")
    gr.Markdown(
        "Upload a CT or MR image (in NIfTI format) and run segmentation using TotalSegmentator. "
        "For MR images, the task flag is set accordingly. A preview of one segmentation slice will be displayed."
    )

    # Inputs.
    with gr.Row():
        uploaded_file = gr.File(label="Upload NIfTI Image (.nii.gz)")
        modality = gr.Radio(
            choices=["CT", "MR"],
            label="Select Image Modality",
            value="CT",
        )

    # Outputs.
    with gr.Row():
        zip_output = gr.File(label="Download Segmentation Output (zip)")
        preview_output = gr.Image(label="Segmentation Preview")

    run_btn = gr.Button("Run Segmentation")
    # The two outputs receive the two elements of run_segmentation's return tuple.
    run_btn.click(
        fn=run_segmentation,
        inputs=[uploaded_file, modality],
        outputs=[zip_output, preview_output],
    )

# Start the app when the module is executed.
demo.launch()