# Page header captured from the Hugging Face file view of this file:
# SegVolOnIDC / app.py
# cciausu97's picture
# Update app.py
# ffd4b8e verified
# raw
# history blame
# 17.2 kB
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from streamlit_image_coordinates import streamlit_image_coordinates
from idc_index import index
import os
import glob
import shutil
import dcm2niix
import subprocess
import random
import base64
from model.data_process.demo_data_process import process_ct_gt
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
import monai.transforms as transforms
from utils import show_points, make_fig, reflect_points_into_model, initial_rectangle, reflect_json_data_to_3D_box, reflect_box_into_model, run
import nibabel as nib
import tempfile
print("script run")  # emitted each time Streamlit executes this script
# TODO(further improvement): stop rebuilding presumably costly objects such as
# index.IDCClient() on every call, e.g. with a singleton decorator or one of
# Streamlit's caching decorators:
# https://docs.streamlit.io/develop/api-reference/caching-and-state/st.experimental_singleton
# https://docs.streamlit.io/develop/concepts/architecture/caching
def download_idc_data_serieUID(serieUID_lst, output_folder):
    """Download IDC series by SeriesInstanceUID and convert them to NIfTI.

    ``output_folder`` is deleted and recreated on every call, so only the
    series requested here end up in it. Series number ``idx`` of
    ``serieUID_lst`` is downloaded into ``ddl_series{idx}_dcm``. The
    ``dcm2niix`` command-line tool then converts it into ``.nii.gz`` files in
    ``ddl_series{idx}_nii``.

    Args:
        serieUID_lst: DICOM SeriesInstanceUID strings to download.
        output_folder: directory receiving one ``_dcm``/``_nii`` folder pair
            per series.

    Returns:
        Sorted paths of every ``.nii.gz`` file found in the ``*nii``
        sub-folders (an empty list when nothing could be converted).
    """
    client = index.IDCClient()
    # Start from an empty folder so files from earlier calls are never returned.
    if os.path.exists(output_folder):
        shutil.rmtree(output_folder)
    os.makedirs(output_folder)
    for idx, serie_uid in enumerate(serieUID_lst):
        sample_dcm_dir = os.path.join(output_folder, f"ddl_series{idx}_dcm")
        sample_nii_dir = os.path.join(output_folder, f"ddl_series{idx}_nii")
        # output_folder was just recreated, so these sub-folders are new.
        for sub_dir in (sample_dcm_dir, sample_nii_dir):
            os.makedirs(sub_dir, exist_ok=True)
        client.download_from_selection(seriesInstanceUID=serie_uid, downloadDir=sample_dcm_dir)
        # "-z y" gzips the output; "-f IDC_%i" is the output file-name pattern.
        return_code = subprocess.call(["dcm2niix", "-o", sample_nii_dir, "-z", "y",
                                       "-f", "IDC_%i", "-g", "y", sample_dcm_dir])
        if return_code != 0:
            # Keep going with the other series; the returned list shows what
            # was actually produced. Just make the failure visible in the logs.
            print(f"dcm2niix exited with code {return_code} for series {serie_uid}")
    # glob order depends on the filesystem; sort for a deterministic result.
    return sorted(glob.glob(os.path.join(output_folder, "*nii/*.nii.gz")))
def get_random_sample_idc_from_bodypart(bodypart_selected, min_instances=100):
    """Pick a random CT series from the IDC index for a given BodyPartExamined.

    Only CT series with strictly more than ``min_instances`` instances are
    candidates.

    Args:
        bodypart_selected: value matched against the ``BodyPartExamined`` column.
        min_instances: exclusive lower bound on ``instanceCount`` (default 100).

    Returns:
        ``(series_instance_uid, viewer_url)`` of the randomly chosen series.

    Raises:
        ValueError: if no series matches the criteria.
    """
    import pandas as pd  # local import: pandas is not imported at module level

    client = index.IDCClient()
    idc_df = client.index
    # Compare instanceCount as a number. Comparing it with the string '100' is
    # lexicographic when the column holds strings ('99' > '100' is True) and
    # is an invalid comparison when the column holds integers.
    instance_count = pd.to_numeric(idc_df['instanceCount'], errors='coerce')
    is_match = (
        idc_df['Modality'].isin(["CT"])
        & (idc_df['BodyPartExamined'] == bodypart_selected)
        & (instance_count > min_instances)
    )
    matching_series_list = idc_df.loc[is_match, 'SeriesInstanceUID'].values
    if len(matching_series_list) == 0:
        raise ValueError(
            f"No CT series with more than {min_instances} instances found "
            f"for BodyPartExamined={bodypart_selected!r}"
        )
    random_series_uid = random.choice(matching_series_list)
    random_series_viewer_url = client.get_viewer_URL(random_series_uid)
    return random_series_uid, random_series_viewer_url
def retrieve_idc_index_body_parts(max_instances=150):
    """Return the distinct BodyPartExamined values of small CT series in the IDC index.

    Args:
        max_instances: exclusive upper bound on ``instanceCount`` (default 150).

    Returns:
        Array of unique ``BodyPartExamined`` values of the matching CT series.
    """
    import pandas as pd  # local import: pandas is not imported at module level

    idc_client = index.IDCClient()
    idc_df = idc_client.index
    # Numeric comparison; comparing with the string '150' would be
    # lexicographic ('1000' < '150') or invalid on an integer column.
    instance_count = pd.to_numeric(idc_df['instanceCount'], errors='coerce')
    is_small_ct = idc_df['Modality'].isin(['CT']) & (instance_count < max_instances)
    body_parts = idc_df.loc[is_small_ct, 'BodyPartExamined'].unique()
    return body_parts
#############################################
# `option` is reassigned by the case-selection widgets further down before
# anything reads it, so every rerun starts from "no case selected".
st.session_state.option = None
if 'idc_data' not in st.session_state:
    # First run of this browser session: fetch the tcga_lihc demo series.
    # download_idc_data_serieUID wipes output_folder before downloading.
    # Sorted so the select-box options keep the same order on every rerun
    # (glob order is filesystem-dependent).
    case_list = sorted(download_idc_data_serieUID(
        serieUID_lst=["1.3.6.1.4.1.14519.5.2.1.8421.4008.125612661111422710051062993644",
                      "1.3.6.1.4.1.14519.5.2.1.3344.4008.552105302448832783460360105045",
                      "1.3.6.1.4.1.14519.5.2.1.3344.4008.217290429362492484143666931850",
                      "1.3.6.1.4.1.14519.5.2.1.3344.4008.315023636447426194723399171147",
                      "1.3.6.1.4.1.14519.5.2.1.3344.4008.307374355712319704057189924161"],
        output_folder="model/asset/idc_samples"))
    st.session_state.idc_data = True
else:
    # Later reruns reuse the files downloaded on the first run.
    case_list = sorted(glob.glob("model/asset/idc_samples/*nii/*.nii.gz"))
def _init_session_state():
    """Give every session_state key used by the UI its starting value.

    Only missing keys are filled in, so values written by widgets, callbacks
    or earlier reruns are kept.
    """
    defaults = {
        'idc_serieUID_sample': None,
        'option': None,
        'text_prompt': None,
        'reset_demo_case': False,
        'data_item': None,
        'points': [],
        'use_text_prompt': False,
        'use_text_serieUID': False,
        'use_point_prompt': False,
        'use_box_prompt': False,
        'rectangle_3Dbox': [0, 0, 0, 0, 0, 0],
        'irregular_box': False,
        'running': False,
        'transparency': 0.25,
    }
    for state_key, initial_value in defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = initial_value
    # preds_3D and preds_3D_ori are created together, keyed on preds_3D only.
    if 'preds_3D' not in st.session_state:
        st.session_state.preds_3D = None
        st.session_state.preds_3D_ori = None


_init_session_state()
#############################################
#############################################
# reset functions
def clear_prompts():
    """Forget all spatial prompts: the 3D box and the clicked points."""
    state = st.session_state
    state.rectangle_3Dbox = [0] * 6
    state.points = []
def reset_demo_case():
    """Drop the loaded case and flag that the selection must be processed again.

    Also forgets the downloaded-series path and the body-part flag, then
    clears all spatial prompts.
    """
    state = st.session_state
    state.data_item = None
    state.idc_serieUID_sample = None
    state.reset_demo_case = True
    state.idc_bodypart_selected = False
    clear_prompts()
def clear_file():
    """Forget the current demo case completely (the case-source radio's on_change).

    Resets the selected file and drops process_ct_gt's cached results. It then
    delegates the remaining resets (loaded data, downloaded-series path,
    body-part flag and prompts) to reset_demo_case().
    """
    st.session_state.option = None
    # process_ct_gt exposes .clear(), so it is presumably wrapped by a
    # Streamlit cache decorator; clearing it forces fresh processing.
    process_ct_gt.clear()
    # reset_demo_case() already sets idc_serieUID_sample=None and
    # idc_bodypart_selected=False and calls clear_prompts(), so those resets
    # are not duplicated here.
    reset_demo_case()
#############################################
# Intro: how to pick a SeriesInstanceUID, then the SegVol overview and links.
st.image("idc_serieUID_selection.gif")
for _intro_text in (
    "What is left is to paste the copied SeriesInstanceUID as showed above into the filter by DICOM SeriesInstanceUID box.",
    "Below is an overview of the SegVol method and authors acknowledgement.",
):
    st.write(_intro_text)
st.image(Image.open('model/asset/overview back.png'), use_column_width=True)
github_col, arxive_col = st.columns(2)
github_col.write('SegVol GitHub repo:https://github.com/BAAI-DCAI/SegVol')
arxive_col.write('SegVol Paper:https://arxiv.org/abs/2311.13385')
# modify demo case here
# Choose where the CT case comes from; switching source forgets the current case.
demo_type = st.radio(
    "Demo case source",
    ["Select an IDC demo case from tcga_lihc collection",
     "Filter by DICOM SeriesInstanceUID",
     "Random sampling based on BodyPartExamined"],
    on_change=clear_file
)
if demo_type == "Select an IDC demo case from tcga_lihc collection":
    uploaded_file = st.selectbox(
        "Select a demo case",
        case_list,
        index=None,
        placeholder="Select a demo case...",
        on_change=reset_demo_case)
elif demo_type == "Filter by DICOM SeriesInstanceUID":
    with st.form("Filter by DICOM SeriesInstanceUID"):
        uploaded_serieUID = st.text_input("Enter a DICOM SeriesInstanceUID", value=None)
        submitted = st.form_submit_button("Submit", on_click=clear_prompts)
    if submitted:
        # text_input returns None for an empty box; never query IDC for "None".
        serie_uid = (uploaded_serieUID or "").strip()
        if not serie_uid:
            st.warning("Please enter a SeriesInstanceUID first.")
        else:
            nii_files = download_idc_data_serieUID([serie_uid], "model/asset/idc_serieUID_sample")
            if nii_files:
                st.session_state.idc_serieUID_sample = nii_files[0]
                # The submit callback only clears prompts; flag the case as new
                # so the volume is reprocessed instead of keeping the old one.
                st.session_state.reset_demo_case = True
            else:
                # The download wiped the previous sample folder, so nothing
                # usable is left on disk.
                st.session_state.idc_serieUID_sample = None
                st.error(f"No NIfTI volume could be produced for SeriesInstanceUID {serie_uid}.")
    uploaded_file = st.session_state.idc_serieUID_sample
else:  # "Random sampling based on BodyPartExamined"
    with st.form("Filter by DICOM BodyPartExamined Tag"):
        body_part_selected = st.selectbox(
            "Select a bodypart to randomly sample a CT scan from",
            ["ABDOMEN", "LUNG", "LIVER",
             "PELVIS"],
            index=None,
            placeholder="Select a bodypart to pick a SeriesInstanceUID from...")
        submitted = st.form_submit_button("Submit", on_click=reset_demo_case)
    if submitted:
        if body_part_selected is None:
            st.warning("Please select a bodypart first.")
        else:
            for _attempt in range(5):
                # Sample a new series on every attempt: retrying the same UID
                # would just reproduce the same conversion result.
                serieUID, _ohif_link = get_random_sample_idc_from_bodypart(body_part_selected)
                nii_files = download_idc_data_serieUID(
                    [str(serieUID)], "model/asset/idc_serieUID_random_sample")
                # Accept the series only if it converts to exactly one volume.
                if len(nii_files) == 1:
                    st.session_state.idc_serieUID_sample = nii_files[0]
                    break
                print("serieUID NOT FILLING BASIC REQs --> MORE THAN 1 NII FILE OR NO NII FILE")
            else:
                st.session_state.idc_serieUID_sample = None
                st.error("Could not find a CT series that converts to a single NIfTI file; please submit again.")
    uploaded_file = st.session_state.idc_serieUID_sample
st.session_state.option = uploaded_file
# Run process_ct_gt on the selected file when a case is selected and either a
# reset was requested or nothing has been loaded yet; old predictions are dropped.
if st.session_state.option is not None and (
    st.session_state.reset_demo_case or st.session_state.data_item is None
):
    st.session_state.data_item = process_ct_gt(st.session_state.option)
    st.session_state.reset_demo_case = False
    st.session_state.preds_3D = st.session_state.preds_3D_ori = None
# Prompt configuration: semantic (text) prompt on the left, spatial prompt on the right.
prompt_col1, prompt_col2 = st.columns(2)
with prompt_col1:
    st.session_state.use_text_prompt = st.toggle('Sematic prompt')
    text_prompt_type = st.radio(
        "Sematic prompt type",
        ["Predefined", "Custom"],
        disabled=(not st.session_state.use_text_prompt)
    )
    if text_prompt_type == "Predefined":
        pre_text = st.selectbox(
            "Predefined anatomical category:",
            ['liver', 'right kidney', 'spleen', 'pancreas', 'aorta', 'inferior vena cava', 'right adrenal gland', 'left adrenal gland', 'gallbladder', 'esophagus', 'stomach', 'duodenum', 'left kidney'],
            index=None,
            disabled=(not st.session_state.use_text_prompt)
        )
    else:
        pre_text = st.text_input('Enter an Anatomical word or phrase:', None, max_chars=20,
                                 disabled=(not st.session_state.use_text_prompt))
    # Both widgets return None when nothing is chosen or typed. An empty or
    # whitespace-only phrase is treated the same way so it cannot enable Run.
    st.session_state.text_prompt = (pre_text or "").strip() or None
with prompt_col2:
    spatial_prompt_on = st.toggle('Spatial prompt', on_change=clear_prompts)
    spatial_prompt = st.radio(
        "Spatial prompt type",
        ["Point prompt", "Box prompt"],
        on_change=clear_prompts,
        disabled=(not spatial_prompt_on))
    st.session_state.enforce_zoom = st.checkbox('Enforce zoom-out-zoom-in')
# The radio has no index=None, so it always returns one of its two options.
# The active spatial prompt kind therefore follows from the toggle and the radio.
st.session_state.use_point_prompt = spatial_prompt_on and spatial_prompt == "Point prompt"
st.session_state.use_box_prompt = spatial_prompt_on and spatial_prompt == "Box prompt"
if not st.session_state.use_text_prompt:
    st.session_state.text_prompt = None
# Two-view slice viewer (X-Y and X-Z) where point / box prompts are placed.
if st.session_state.option is None:
    st.write('please select demo case first')
else:
    image_3D = st.session_state.data_item['z_image'][0].numpy()
    # Keep the historical default slice (162) but clamp it into range:
    # st.slider rejects a default value above max_value.
    default_index_z = min(162, image_3D.shape[0] - 1)
    default_index_y = min(162, image_3D.shape[1] - 1)
    col_control1, col_control2 = st.columns(2)
    with col_control1:
        selected_index_z = st.slider('X-Y view', 0, image_3D.shape[0] - 1, default_index_z, key='xy', disabled=st.session_state.running)
    with col_control2:
        selected_index_y = st.slider('X-Z view', 0, image_3D.shape[1] - 1, default_index_y, key='xz', disabled=st.session_state.running)
    if st.session_state.use_box_prompt:
        # Box extent along the first volume axis (the rows of the X-Z view).
        # The 0..324 range presumably matches the 325-px display size used below.
        top, bottom = st.select_slider(
            'Top and bottom of box',
            options=range(0, 325),
            value=(0, 324),
            disabled=st.session_state.running
        )
        st.session_state.rectangle_3Dbox[0] = top
        st.session_state.rectangle_3Dbox[3] = bottom
    col_image1, col_image2 = st.columns(2)
    if st.session_state.preds_3D is not None:
        st.session_state.transparency = st.slider('Mask opacity', 0.0, 1.0, 0.25, disabled=st.session_state.running)
    with col_image1:
        image_z_array = image_3D[selected_index_z]
        preds_z_array = None
        if st.session_state.preds_3D is not None:
            preds_z_array = st.session_state.preds_3D[selected_index_z]
        image_z = make_fig(image_z_array, preds_z_array, st.session_state.points, selected_index_z, 'xy')
        if st.session_state.use_point_prompt:
            value_xy = streamlit_image_coordinates(image_z, width=325)
            if value_xy is not None:
                point_ax_xy = (selected_index_z, value_xy['y'], value_xy['x'])
                if len(st.session_state.points) >= 3:
                    # The icon must be a real emoji; "??" makes st.warning raise.
                    st.warning('Max point num is 3', icon="⚠️")
                elif point_ax_xy not in st.session_state.points:
                    st.session_state.points.append(point_ax_xy)
                    print('point_ax_xy add rerun')
                    st.rerun()
        elif st.session_state.use_box_prompt:
            canvas_result_xy = st_canvas(
                fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
                stroke_width=3,
                stroke_color='#2909F1',
                background_image=image_z,
                update_streamlit=True,
                height=325,
                width=325,
                drawing_mode='transform',
                point_display_radius=0,
                key="canvas_xy",
                initial_drawing=initial_rectangle,
                display_toolbar=True
            )
            try:
                box_angle = canvas_result_xy.json_data['objects'][0]['angle']
                print(box_angle)
                if box_angle != 0:
                    # A rotated box is flagged as irregular, which disables Run.
                    st.warning('Rotating is undefined behavior', icon="⚠️")
                    st.session_state.irregular_box = True
                else:
                    st.session_state.irregular_box = False
                reflect_json_data_to_3D_box(canvas_result_xy.json_data, view='xy')
            except Exception as exc:
                # Best effort: the canvas may not report a drawn object yet
                # (json_data None or empty 'objects'). Catching Exception rather
                # than everything lets BaseException-derived interrupts through.
                print(f'exception: {exc!r}')
        else:
            st.image(image_z, use_column_width=False)
    with col_image2:
        image_y_array = image_3D[:, selected_index_y, :]
        preds_y_array = None
        if st.session_state.preds_3D is not None:
            preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]
        image_y = make_fig(image_y_array, preds_y_array, st.session_state.points, selected_index_y, 'xz')
        if st.session_state.use_point_prompt:
            value_yz = streamlit_image_coordinates(image_y, width=325)
            if value_yz is not None:
                point_ax_xz = (value_yz['y'], selected_index_y, value_yz['x'])
                if len(st.session_state.points) >= 3:
                    st.warning('Max point num is 3', icon="⚠️")
                elif point_ax_xz not in st.session_state.points:
                    st.session_state.points.append(point_ax_xz)
                    print('point_ax_xz add rerun')
                    st.rerun()
        elif st.session_state.use_box_prompt:
            box = st.session_state.rectangle_3Dbox
            # Overlay the box only when the current X-Z slice lies inside its
            # range on the second axis.
            if box[1] <= selected_index_y <= box[4]:
                draw = ImageDraw.Draw(image_y)
                # Upper-left (x, y) = (box[2], box[0]); lower-right = (box[5], box[3]).
                rectangle_coords = [(box[2], box[0]), (box[5], box[3])]
                draw.rectangle(rectangle_coords, outline='#2909F1', width=3)
            st.image(image_y, use_column_width=False)
        else:
            st.image(image_y, use_column_width=False)
# Action row: Clear | Download | Run.
col1, col2, col3 = st.columns(3)
with col1:
    # Nothing to clear: no points, box prompt off and no prediction shown.
    nothing_to_clear = (
        len(st.session_state.points) == 0
        and not st.session_state.use_box_prompt
        and st.session_state.preds_3D is None
    )
    clear_disabled = st.session_state.option is None or nothing_to_clear
    if st.button("Clear", use_container_width=True, disabled=clear_disabled):
        clear_prompts()
        st.session_state.preds_3D = None
        st.session_state.preds_3D_ori = None
        st.rerun()
with col2:
    # Offer the prediction as a .nii.gz download once one exists.
    img_nii = None
    if st.session_state.preds_3D_ori is not None and st.session_state.data_item is not None:
        data_item = st.session_state.data_item
        meta_dict = data_item['meta']
        foreground_start_coord = data_item['foreground_start_coord']
        foreground_end_coord = data_item['foreground_end_coord']
        original_shape = data_item['ori_shape']
        pred_array = st.session_state.preds_3D_ori
        # Paste the prediction into a zero-filled array of shape 'ori_shape'
        # between the foreground start/end coordinates.
        original_array = np.zeros(original_shape)
        original_array[foreground_start_coord[0]:foreground_end_coord[0],
                       foreground_start_coord[1]:foreground_end_coord[1],
                       foreground_start_coord[2]:foreground_end_coord[2]] = pred_array
        # Reverse the axis order; presumably undoes an axis swap made when the
        # case was loaded - TODO confirm against process_ct_gt.
        original_array = original_array.transpose(2, 1, 0)
        img_nii = nib.Nifti1Image(original_array, affine=meta_dict['affine'])
        # Save inside a private temporary directory. Reopening a
        # NamedTemporaryFile by name while it is still open fails on Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = os.path.join(tmp_dir, "segvol_preds.nii.gz")
            nib.save(img_nii, tmp_path)
            with open(tmp_path, "rb") as f:
                bytes_data = f.read()
        st.download_button(
            label="Download result(.nii.gz)",
            data=bytes_data,
            file_name="segvol_preds.nii.gz",
            mime="application/octet-stream",
        )
with col3:
    is_running = st.session_state.running
    run_button_name = 'Running' if is_running else 'Run'
    # Run needs a loaded case and at least one prompt (text, point or box),
    # and is blocked for rotated boxes or while a run is in progress.
    no_prompt = (
        st.session_state.text_prompt is None
        and len(st.session_state.points) == 0
        and not st.session_state.use_box_prompt
    )
    run_disabled = (
        st.session_state.data_item is None
        or no_prompt
        or st.session_state.irregular_box
        or is_running
    )
    if st.button(run_button_name, type="primary", use_container_width=True, disabled=run_disabled):
        # Flag the run and rerun so the UI first redraws in its "Running" state.
        st.session_state.running = True
        st.rerun()
if st.session_state.running:
    # Reset the flag before calling run() so a failure cannot leave it set.
    st.session_state.running = False
    with st.status("Running...", expanded=False):
        run()
    st.rerun()