"""Streamlit demo for SegVol on CT series fetched from the Imaging Data Commons (IDC).

The script downloads DICOM series with ``idc_index``, converts them to NIfTI
with ``dcm2niix``, lets the user pick semantic / point / box prompts on two
orthogonal views, runs the model through ``utils.run`` and offers the
prediction as a ``.nii.gz`` download.

NOTE(review): this file was reconstructed from a whitespace-collapsed source;
nesting was inferred from syntax. Places where the nesting was ambiguous are
marked with NOTE(review) comments below.
"""
import base64
import glob
import os
import random
import shutil
import subprocess
import tempfile

import dcm2niix
import matplotlib.pyplot as plt
import monai.transforms as transforms
import nibabel as nib
import numpy as np
import pandas as pd
import streamlit as st
from idc_index import index
from PIL import Image, ImageDraw
from streamlit_drawable_canvas import st_canvas
from streamlit_image_coordinates import streamlit_image_coordinates

from model.data_process.demo_data_process import process_ct_gt
from utils import show_points, make_fig, reflect_points_into_model, initial_rectangle, reflect_json_data_to_3D_box, reflect_box_into_model, run

print('script run')

# Further improvement: wrap the IDC client in a singleton / cached resource.
# https://docs.streamlit.io/develop/api-reference/caching-and-state/st.experimental_singleton
# https://docs.streamlit.io/develop/concepts/architecture/caching

# Folders that receive the downloaded / converted samples.
IDC_SAMPLES_DIR = "model/asset/idc_samples"
IDC_SERIE_UID_SAMPLE_DIR = "model/asset/idc_serieUID_sample"
IDC_RANDOM_SAMPLE_DIR = "model/asset/idc_serieUID_random_sample"

# Labels of the "Demo case source" radio; also used to dispatch on the choice.
DEMO_SOURCE_COLLECTION = "Select an IDC demo case from tcga_lihc collection"
DEMO_SOURCE_SERIE_UID = "Filter by DICOM SeriesInstanceUID"
DEMO_SOURCE_RANDOM = "Random sampling based on BodyPartExamined"

# How many different random series are tried before giving up.
MAX_RANDOM_SAMPLE_ATTEMPTS = 5

# Default slice shown by the two view sliders (clamped to the volume size).
DEFAULT_SLICE_INDEX = 162

WARNING_ICON = "⚠️"


def _instance_count(idc_index):
    """Return the ``instanceCount`` column of *idc_index* as numbers.

    The column is converted explicitly so that thresholds are compared
    numerically whatever dtype the index was loaded with; comparing against
    a string such as ``'100'`` would be lexicographic ('99' > '100').
    Unparseable values become NaN and therefore never satisfy a threshold.
    """
    return pd.to_numeric(idc_index['instanceCount'], errors='coerce')


def download_idc_data_serieUID(serieUID_lst, output_folder):
    """Download IDC series and convert each one to ``.nii.gz``.

    *output_folder* is wiped and recreated. For the series at position
    ``idx`` the DICOM files go to ``ddl_series{idx}_dcm`` and the dcm2niix
    output to ``ddl_series{idx}_nii`` inside *output_folder*.

    Args:
        serieUID_lst: iterable of DICOM SeriesInstanceUID strings.
        output_folder: folder receiving the per-series sub-folders.

    Returns:
        Sorted list of the ``.nii.gz`` paths found under the ``*nii``
        sub-folders; empty when nothing could be converted.
    """
    client = index.IDCClient()
    if os.path.exists(output_folder):
        shutil.rmtree(output_folder)
    os.makedirs(output_folder)
    for idx, serieUID_ddl in enumerate(serieUID_lst):
        sample_dcm_dir = os.path.join(output_folder, f"ddl_series{idx}_dcm")
        sample_nii_dir = os.path.join(output_folder, f"ddl_series{idx}_nii")
        for directory in (sample_dcm_dir, sample_nii_dir):
            if os.path.exists(directory):
                shutil.rmtree(directory)
            os.makedirs(directory)
        client.download_from_selection(seriesInstanceUID=serieUID_ddl, downloadDir=sample_dcm_dir)
        return_code = subprocess.call(
            ["dcm2niix", "-o", sample_nii_dir, "-z", "y", "-f", "IDC_%i", "-g", "y", sample_dcm_dir]
        )
        if return_code != 0:
            # Best effort: callers detect the failure through the missing output file.
            print(f"dcm2niix exited with code {return_code} for series {serieUID_ddl}")
    return sorted(glob.glob(os.path.join(output_folder, "*nii/*.nii.gz")))


def get_random_sample_idc_from_bodypart(bodypart_selected):
    """Pick a random CT series of the given body part with more than 100 instances.

    Args:
        bodypart_selected: value matched against the ``BodyPartExamined`` column.

    Returns:
        Tuple ``(series_uid, viewer_url)`` for the sampled series.

    Raises:
        IndexError: when no series in the IDC index matches the filter.
    """
    client = index.IDCClient()
    idc_index = client.index
    selection = (
        idc_index['Modality'].isin(["CT"])
        & (idc_index['BodyPartExamined'] == bodypart_selected)
        & (_instance_count(idc_index) > 100)
    )
    matching_series_list = idc_index[selection]['SeriesInstanceUID'].values
    if len(matching_series_list) == 0:
        raise IndexError(f"no CT series found in the IDC index for body part {bodypart_selected!r}")
    # select random series from the list
    random_series_uid = random.choice(matching_series_list)
    random_series_viewer_url = client.get_viewer_URL(random_series_uid)
    return random_series_uid, random_series_viewer_url


def retrieve_idc_index_body_parts():
    """Return the distinct ``BodyPartExamined`` values of CT series with fewer than 150 instances."""
    idc_client = index.IDCClient()
    idc_index = idc_client.index
    selection = idc_index['Modality'].isin(['CT']) & (_instance_count(idc_index) < 150)
    body_parts = idc_index[selection]['BodyPartExamined'].unique()
    return body_parts


#############################################
# The demo collection is downloaded once per session; later reruns reuse the files.
if 'idc_data' not in st.session_state:
    case_list = download_idc_data_serieUID(
        serieUID_lst=[
            "1.3.6.1.4.1.14519.5.2.1.8421.4008.125612661111422710051062993644",
            "1.3.6.1.4.1.14519.5.2.1.3344.4008.552105302448832783460360105045",
            "1.3.6.1.4.1.14519.5.2.1.3344.4008.217290429362492484143666931850",
            "1.3.6.1.4.1.14519.5.2.1.3344.4008.315023636447426194723399171147",
            "1.3.6.1.4.1.14519.5.2.1.3344.4008.307374355712319704057189924161",
        ],
        output_folder=IDC_SAMPLES_DIR,
    )
    st.session_state.idc_data = True
else:
    case_list = sorted(glob.glob(os.path.join(IDC_SAMPLES_DIR, "*nii/*.nii.gz")))

if 'idc_serieUID_sample' not in st.session_state:
    st.session_state.idc_serieUID_sample = None

# init session_state
if 'option' not in st.session_state:
    st.session_state.option = None
if 'text_prompt' not in st.session_state:
    st.session_state.text_prompt = None
if 'reset_demo_case' not in st.session_state:
    st.session_state.reset_demo_case = False
if 'preds_3D' not in st.session_state:
    st.session_state.preds_3D = None
    st.session_state.preds_3D_ori = None
if 'data_item' not in st.session_state:
    st.session_state.data_item = None
if 'points' not in st.session_state:
    st.session_state.points = []
if 'use_text_prompt' not in st.session_state:
    st.session_state.use_text_prompt = False
if 'use_text_serieUID' not in st.session_state:
    st.session_state.use_text_serieUID = False
if 'use_point_prompt' not in st.session_state:
    st.session_state.use_point_prompt = False
if 'use_box_prompt' not in st.session_state:
    st.session_state.use_box_prompt = False
if 'rectangle_3Dbox' not in st.session_state:
    st.session_state.rectangle_3Dbox = [0, 0, 0, 0, 0, 0]
if 'irregular_box' not in st.session_state:
    st.session_state.irregular_box = False
if 'running' not in st.session_state:
    st.session_state.running = False
if 'transparency' not in st.session_state:
    st.session_state.transparency = 0.25
#############################################


#############################################
# reset functions
def clear_prompts():
    """Drop all clicked points and reset the 3D box prompt."""
    st.session_state.points = []
    st.session_state.rectangle_3Dbox = [0, 0, 0, 0, 0, 0]


def reset_demo_case():
    """Forget the loaded volume and the downloaded sample, and request a reload."""
    st.session_state.data_item = None
    st.session_state.idc_serieUID_sample = None
    st.session_state.reset_demo_case = True
    st.session_state.idc_bodypart_selected = False
    clear_prompts()


def clear_file():
    """Reset everything tied to the current case, including the ``process_ct_gt`` cache."""
    st.session_state.option = None
    st.session_state.idc_serieUID_sample = None
    st.session_state.idc_bodypart_selected = False
    process_ct_gt.clear()
    reset_demo_case()
    clear_prompts()


def _register_downloaded_sample(nii_paths):
    """Make the first path of *nii_paths* the active sample and force a reload.

    A fresh download may reuse the file path of the previous one, so the
    ``process_ct_gt`` cache is cleared and ``reset_demo_case`` is raised;
    otherwise the previously loaded volume would keep being displayed.

    Returns:
        The registered path, or ``None`` when *nii_paths* is empty.
    """
    if not nii_paths:
        st.session_state.idc_serieUID_sample = None
        return None
    process_ct_gt.clear()
    st.session_state.idc_serieUID_sample = nii_paths[0]
    st.session_state.reset_demo_case = True
    return nii_paths[0]
#############################################


st.image("idc_serieUID_selection.gif")
st.write("What is left is to paste the copied SeriesInstanceUID as showed above into the filter by DICOM SeriesInstanceUID box.")
st.write("Below is an overview of the SegVol method and authors acknowledgement.")
st.image(Image.open('model/asset/overview back.png'), use_column_width=True)

github_col, arxive_col = st.columns(2)
with github_col:
    st.write('SegVol GitHub repo:https://github.com/BAAI-DCAI/SegVol')
with arxive_col:
    st.write('SegVol Paper:https://arxiv.org/abs/2311.13385')

# modify demo case here
demo_type = st.radio(
    "Demo case source",
    [DEMO_SOURCE_COLLECTION, DEMO_SOURCE_SERIE_UID, DEMO_SOURCE_RANDOM],
    on_change=clear_file
)

if demo_type == DEMO_SOURCE_COLLECTION:
    uploaded_file = st.selectbox(
        "Select a demo case",
        case_list,
        index=None,
        placeholder="Select a demo case...",
        on_change=reset_demo_case
    )
elif demo_type == DEMO_SOURCE_SERIE_UID:
    with st.form("Filter by DICOM SeriesInstanceUID"):
        uploaded_serieUID = st.text_input("Enter a DICOM SeriesInstanceUID", value=None)
        submitted = st.form_submit_button("Submit", on_click=clear_prompts)
    if submitted:
        # An untouched text box yields None; never send the literal "None" to IDC.
        serie_uid = (uploaded_serieUID or "").strip()
        if not serie_uid:
            st.warning('Please enter a DICOM SeriesInstanceUID before submitting.', icon=WARNING_ICON)
        elif _register_downloaded_sample(
                download_idc_data_serieUID([serie_uid], IDC_SERIE_UID_SAMPLE_DIR)) is None:
            st.error(f"No NIfTI volume could be produced for SeriesInstanceUID {serie_uid}.")
    uploaded_file = st.session_state.idc_serieUID_sample
else:  # DEMO_SOURCE_RANDOM
    with st.form("Filter by DICOM BodyPartExamined Tag"):
        body_part_selected = st.selectbox(
            "Select a bodypart to randomly sample a CT scan from",
            ["ABDOMEN", "LUNG", "LIVER", "PELVIS"],
            index=None,
            placeholder="Select a bodypart to pick a SeriesInstanceUID from...")
        submitted = st.form_submit_button("Submit", on_click=reset_demo_case)
    if submitted:
        if body_part_selected is None:
            st.warning('Please select a body part before submitting.', icon=WARNING_ICON)
        else:
            sampled_paths = []
            for _ in range(MAX_RANDOM_SAMPLE_ATTEMPTS):
                # Draw a NEW series on every attempt; retrying the same one cannot succeed.
                try:
                    serieUID, _ohif_link = get_random_sample_idc_from_bodypart(body_part_selected)
                except IndexError as exc:
                    st.error(str(exc))
                    break
                candidate_paths = download_idc_data_serieUID([str(serieUID)], IDC_RANDOM_SAMPLE_DIR)
                # The series is usable only if it converts to exactly one volume.
                if len(candidate_paths) == 1:
                    sampled_paths = candidate_paths
                    break
                print("serieUID NOT FILLING BASIC REQs --> MORE THAN 1 NII FILE OR NO NII FILE")
            if _register_downloaded_sample(sampled_paths) is None:
                st.error("Could not find a random series that converts to a single NIfTI volume.")
    uploaded_file = st.session_state.idc_serieUID_sample

st.session_state.option = uploaded_file

# (Re)load the volume when a case is selected and either a reload was requested
# or nothing is loaded yet.
if st.session_state.option is not None and (
        st.session_state.reset_demo_case or st.session_state.data_item is None):
    st.session_state.data_item = process_ct_gt(st.session_state.option)
    st.session_state.reset_demo_case = False
    st.session_state.preds_3D = None
    st.session_state.preds_3D_ori = None

prompt_col1, prompt_col2 = st.columns(2)

with prompt_col1:
    st.session_state.use_text_prompt = st.toggle('Sematic prompt')
    text_prompt_type = st.radio(
        "Sematic prompt type",
        ["Predefined", "Custom"],
        disabled=(not st.session_state.use_text_prompt)
    )
    if text_prompt_type == "Predefined":
        pre_text = st.selectbox(
            "Predefined anatomical category:",
            ['liver', 'right kidney', 'spleen', 'pancreas', 'aorta', 'inferior vena cava',
             'right adrenal gland', 'left adrenal gland', 'gallbladder', 'esophagus',
             'stomach', 'duodenum', 'left kidney'],
            index=None,
            disabled=(not st.session_state.use_text_prompt)
        )
    else:
        pre_text = st.text_input('Enter an Anatomical word or phrase:', None, max_chars=20,
                                 disabled=(not st.session_state.use_text_prompt))
    # An empty string means "no prompt"; None is passed through unchanged.
    if pre_text is None or len(pre_text) > 0:
        st.session_state.text_prompt = pre_text
    else:
        st.session_state.text_prompt = None

with prompt_col2:
    spatial_prompt_on = st.toggle('Spatial prompt', on_change=clear_prompts)
    spatial_prompt = st.radio(
        "Spatial prompt type",
        ["Point prompt", "Box prompt"],
        on_change=clear_prompts,
        disabled=(not spatial_prompt_on))
    st.session_state.enforce_zoom = st.checkbox('Enforce zoom-out-zoom-in')

if spatial_prompt == "Point prompt":
    st.session_state.use_point_prompt = True
    st.session_state.use_box_prompt = False
elif spatial_prompt == "Box prompt":
    st.session_state.use_box_prompt = True
    st.session_state.use_point_prompt = False
else:
    st.session_state.use_point_prompt = False
    st.session_state.use_box_prompt = False

if not spatial_prompt_on:
    st.session_state.use_point_prompt = False
    st.session_state.use_box_prompt = False

if not st.session_state.use_text_prompt:
    st.session_state.text_prompt = None

if st.session_state.option is None:
    st.write('please select demo case first')
else:
    image_3D = st.session_state.data_item['z_image'][0].numpy()
    max_index_z = image_3D.shape[0] - 1
    max_index_y = image_3D.shape[1] - 1
    col_control1, col_control2 = st.columns(2)
    with col_control1:
        # The default is clamped so that a smaller volume does not make the slider raise.
        selected_index_z = st.slider('X-Y view', 0, max_index_z, min(DEFAULT_SLICE_INDEX, max_index_z),
                                     key='xy', disabled=st.session_state.running)
    with col_control2:
        selected_index_y = st.slider('X-Z view', 0, max_index_y, min(DEFAULT_SLICE_INDEX, max_index_y),
                                     key='xz', disabled=st.session_state.running)
        # NOTE(review): placement inside this column was inferred — confirm.
        if st.session_state.use_box_prompt:
            top, bottom = st.select_slider(
                'Top and bottom of box',
                options=range(0, 325),
                value=(0, 324),
                disabled=st.session_state.running
            )
            st.session_state.rectangle_3Dbox[0] = top
            st.session_state.rectangle_3Dbox[3] = bottom

    col_image1, col_image2 = st.columns(2)

    if st.session_state.preds_3D is not None:
        st.session_state.transparency = st.slider('Mask opacity', 0.0, 1.0, 0.25,
                                                  disabled=st.session_state.running)

    with col_image1:
        image_z_array = image_3D[selected_index_z]
        preds_z_array = None
        if st.session_state.preds_3D is not None:
            preds_z_array = st.session_state.preds_3D[selected_index_z]
        image_z = make_fig(image_z_array, preds_z_array, st.session_state.points, selected_index_z, 'xy')

        if st.session_state.use_point_prompt:
            value_xy = streamlit_image_coordinates(image_z, width=325)
            if value_xy is not None:
                point_ax_xy = (selected_index_z, value_xy['y'], value_xy['x'])
                if len(st.session_state.points) >= 3:
                    st.warning('Max point num is 3', icon=WARNING_ICON)
                elif point_ax_xy not in st.session_state.points:
                    st.session_state.points.append(point_ax_xy)
                    print('point_ax_xy add rerun')
                    st.rerun()
        elif st.session_state.use_box_prompt:
            canvas_result_xy = st_canvas(
                fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
                stroke_width=3,
                stroke_color='#2909F1',
                background_image=image_z,
                update_streamlit=True,
                height=325,
                width=325,
                drawing_mode='transform',
                point_display_radius=0,
                key="canvas_xy",
                initial_drawing=initial_rectangle,
                display_toolbar=True
            )
            try:
                box_angle = canvas_result_xy.json_data['objects'][0]['angle']
                print(box_angle)
                if box_angle != 0:
                    st.warning('Rotating is undefined behavior', icon=WARNING_ICON)
                    st.session_state.irregular_box = True
                else:
                    st.session_state.irregular_box = False
                # NOTE(review): assumed to run for both branches above — confirm.
                reflect_json_data_to_3D_box(canvas_result_xy.json_data, view='xy')
            except Exception as exc:
                # Best effort: the canvas may not have reported any object yet.
                print('exception', exc)
        else:
            st.image(image_z, use_column_width=False)

    with col_image2:
        image_y_array = image_3D[:, selected_index_y, :]
        preds_y_array = None
        if st.session_state.preds_3D is not None:
            preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]
        image_y = make_fig(image_y_array, preds_y_array, st.session_state.points, selected_index_y, 'xz')

        if st.session_state.use_point_prompt:
            value_yz = streamlit_image_coordinates(image_y, width=325)
            if value_yz is not None:
                point_ax_xz = (value_yz['y'], selected_index_y, value_yz['x'])
                if len(st.session_state.points) >= 3:
                    st.warning('Max point num is 3', icon=WARNING_ICON)
                elif point_ax_xz not in st.session_state.points:
                    st.session_state.points.append(point_ax_xz)
                    print('point_ax_xz add rerun')
                    st.rerun()
        elif st.session_state.use_box_prompt:
            box = st.session_state.rectangle_3Dbox
            # Outline the box only on slices that lie between box[1] and box[4].
            if box[1] <= selected_index_y <= box[4]:
                draw = ImageDraw.Draw(image_y)
                # rectangle xz view (upper-left and lower-right)
                rectangle_coords = [(box[2], box[0]), (box[5], box[3])]
                # Draw the rectangle on the image
                draw.rectangle(rectangle_coords, outline='#2909F1', width=3)
            st.image(image_y, use_column_width=False)
        else:
            st.image(image_y, use_column_width=False)

col1, col2, col3 = st.columns(3)

with col1:
    nothing_to_clear = (
        len(st.session_state.points) == 0
        and not st.session_state.use_box_prompt
        and st.session_state.preds_3D is None
    )
    if st.button("Clear", use_container_width=True,
                 disabled=(st.session_state.option is None or nothing_to_clear)):
        clear_prompts()
        st.session_state.preds_3D = None
        st.session_state.preds_3D_ori = None
        st.rerun()

with col2:
    img_nii = None
    if st.session_state.preds_3D_ori is not None and st.session_state.data_item is not None:
        meta_dict = st.session_state.data_item['meta']
        foreground_start_coord = st.session_state.data_item['foreground_start_coord']
        foreground_end_coord = st.session_state.data_item['foreground_end_coord']
        original_shape = st.session_state.data_item['ori_shape']
        pred_array = st.session_state.preds_3D_ori
        # Paste the prediction into a zero volume of the original shape,
        # at the foreground window recorded in data_item.
        original_array = np.zeros(original_shape)
        original_array[foreground_start_coord[0]:foreground_end_coord[0],
                       foreground_start_coord[1]:foreground_end_coord[1],
                       foreground_start_coord[2]:foreground_end_coord[2]] = pred_array
        original_array = original_array.transpose(2, 1, 0)
        img_nii = nib.Nifti1Image(original_array, affine=meta_dict['affine'])

        # A temporary directory avoids re-opening a still-open NamedTemporaryFile
        # by name, which is not possible on every platform.
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp_path = os.path.join(tmpdir, "segvol_preds.nii.gz")
            nib.save(img_nii, tmp_path)
            with open(tmp_path, "rb") as f:
                bytes_data = f.read()
        st.download_button(
            label="Download result(.nii.gz)",
            data=bytes_data,
            file_name="segvol_preds.nii.gz",
            mime="application/octet-stream",
            disabled=img_nii is None
        )

with col3:
    run_button_name = 'Run' if not st.session_state.running else 'Running'
    no_prompt_given = (
        st.session_state.text_prompt is None
        and len(st.session_state.points) == 0
        and st.session_state.use_box_prompt is False
    )
    if st.button(run_button_name, type="primary", use_container_width=True,
                 disabled=(
                     st.session_state.data_item is None
                     or no_prompt_given
                     or st.session_state.irregular_box
                     or st.session_state.running
                 )):
        # Rerun first so the widgets are drawn disabled while the model runs.
        st.session_state.running = True
        st.rerun()

if st.session_state.running:
    st.session_state.running = False
    with st.status("Running...", expanded=False):
        run()
    st.rerun()