cciausu97 commited on
Commit
e54db1d
·
verified ·
1 Parent(s): 3cf80c9

update app.py -- added filtering by SeriesInstanceUID

Browse files
Files changed (1) hide show
  1. app.py +274 -271
app.py CHANGED
@@ -22,106 +22,102 @@ print('script run')
22
  #decorator singletion or use cache data class
23
  # https://docs.streamlit.io/develop/api-reference/caching-and-state/st.experimental_singleton
24
  # https://docs.streamlit.io/develop/concepts/architecture/caching
25
- def download_idc_data():
26
- #download IDC data cases
27
- client = index.IDCClient()
28
- #define serieUIDs to download
29
- serieUIDs_list = ["1.3.6.1.4.1.14519.5.2.1.8421.4008.125612661111422710051062993644",
30
- "1.3.6.1.4.1.14519.5.2.1.3344.4008.552105302448832783460360105045",
31
- "1.3.6.1.4.1.14519.5.2.1.3344.4008.217290429362492484143666931850",
32
- "1.3.6.1.4.1.14519.5.2.1.3344.4008.315023636447426194723399171147",
33
- "1.3.6.1.4.1.14519.5.2.1.3344.4008.307374355712319704057189924161"]
34
-
35
- #download series and convert to .nii.gz
36
- samples_root_dir = "model/asset/idc_samples"
37
- if not os.path.exists(samples_root_dir) : os.mkdir(samples_root_dir)
38
- for idx, serieUID_ddl in enumerate(serieUIDs_list):
39
- sample_dcm_dir = os.path.join(samples_root_dir, f"ddl_series{idx}_dcm")
40
- sample_nii_dir = os.path.join(samples_root_dir, f"ddl_series{idx}_nii")
41
- for dir in [sample_dcm_dir, sample_nii_dir]:
42
- if os.path.exists(dir):
43
- shutil.rmtree(dir)
44
- os.makedirs(dir)
45
- client.download_from_selection(seriesInstanceUID=\
46
- serieUID_ddl, \
47
- downloadDir=sample_dcm_dir)
48
- subprocess.call(["dcm2niix", "-o", sample_nii_dir, "-z", "y",
49
- "-f", "IDC_%i", "-g", "y", sample_dcm_dir])
50
- return glob.glob(os.path.join(samples_root_dir, "*nii/*.nii.gz"))
51
 
52
  #############################################
53
  st.session_state.option = None
54
 
55
  if 'idc_data' not in st.session_state:
56
- case_list = download_idc_data()
57
- st.session_state.idc_data = True
 
 
 
 
 
58
  else:
59
- case_list = glob.glob("model/asset/idc_samples/*nii/*.nii.gz")
 
 
60
  # init session_state
61
  if 'option' not in st.session_state:
62
- st.session_state.option = None
63
  if 'text_prompt' not in st.session_state:
64
- st.session_state.text_prompt = None
65
-
66
  if 'reset_demo_case' not in st.session_state:
67
- st.session_state.reset_demo_case = False
68
 
69
  if 'preds_3D' not in st.session_state:
70
- st.session_state.preds_3D = None
71
- st.session_state.preds_3D_ori = None
72
 
73
  if 'data_item' not in st.session_state:
74
- st.session_state.data_item = None
75
 
76
  if 'points' not in st.session_state:
77
- st.session_state.points = []
78
 
79
  if 'use_text_prompt' not in st.session_state:
80
- st.session_state.use_text_prompt = False
 
 
 
81
 
82
  if 'use_point_prompt' not in st.session_state:
83
- st.session_state.use_point_prompt = False
84
 
85
  if 'use_box_prompt' not in st.session_state:
86
- st.session_state.use_box_prompt = False
87
 
88
  if 'rectangle_3Dbox' not in st.session_state:
89
- st.session_state.rectangle_3Dbox = [0,0,0,0,0,0]
90
 
91
  if 'irregular_box' not in st.session_state:
92
- st.session_state.irregular_box = False
93
 
94
  if 'running' not in st.session_state:
95
- st.session_state.running = False
96
 
97
  if 'transparency' not in st.session_state:
98
- st.session_state.transparency = 0.25
99
-
100
- # case_list = [
101
- # '/content/SegVolIDCTest/model/asset/FLARE22_Tr_0002_0000.nii.gz',
102
- # '/content/SegVolIDCTest/model/asset/FLARE22_Tr_0005_0000.nii.gz',
103
- # '/content/SegVolIDCTest/model/asset/FLARE22_Tr_0034_0000.nii.gz',
104
- # '/content/SegVolIDCTest/model/asset/FLARE22_Tr_0045_0000.nii.gz'
105
- # ]
106
-
107
  #############################################
108
 
109
  #############################################
110
  # reset functions
111
  def clear_prompts():
112
- st.session_state.points = []
113
- st.session_state.rectangle_3Dbox = [0,0,0,0,0,0]
114
 
115
  def reset_demo_case():
116
- st.session_state.data_item = None
117
- st.session_state.reset_demo_case = True
118
- clear_prompts()
 
119
 
120
  def clear_file():
121
- st.session_state.option = None
122
- process_ct_gt.clear()
123
- reset_demo_case()
124
- clear_prompts()
 
125
 
126
  #############################################
127
 
@@ -130,252 +126,259 @@ st.image(Image.open('model/asset/overview back.png'), use_column_width=True)
130
  github_col, arxive_col = st.columns(2)
131
 
132
  with github_col:
133
- st.write('GitHub repo:https://github.com/BAAI-DCAI/SegVol')
134
 
135
  with arxive_col:
136
- st.write('Paper:https://arxiv.org/abs/2311.13385')
137
 
138
 
139
  # modify demo case here
140
  demo_type = st.radio(
141
- "Demo case source",
142
- ["Select", "Upload"],
143
- on_change=clear_file
144
- )
145
-
146
- if demo_type=="Select":
147
- uploaded_file = st.selectbox(
148
- "Select a demo case",
149
- case_list,
150
- index=None,
151
- placeholder="Select a demo case...",
152
- on_change=reset_demo_case
153
- )
154
  else:
155
- uploaded_file = st.file_uploader("Upload demo case(nii.gz)", type='nii.gz', on_change=reset_demo_case)
156
-
157
- st.session_state.option = uploaded_file
 
 
 
 
 
 
 
 
158
 
159
  if st.session_state.option is not None and \
160
- st.session_state.reset_demo_case or (st.session_state.data_item is None and st.session_state.option is not None):
161
 
162
- st.session_state.data_item = process_ct_gt(st.session_state.option)
163
- st.session_state.reset_demo_case = False
164
- st.session_state.preds_3D = None
165
- st.session_state.preds_3D_ori = None
166
 
167
  prompt_col1, prompt_col2 = st.columns(2)
168
 
169
  with prompt_col1:
170
- st.session_state.use_text_prompt = st.toggle('Sematic prompt')
171
- text_prompt_type = st.radio(
172
- "Sematic prompt type",
173
- ["Predefined", "Custom"],
174
- disabled=(not st.session_state.use_text_prompt)
175
- )
176
- if text_prompt_type == "Predefined":
177
- pre_text = st.selectbox(
178
- "Predefined anatomical category:",
179
- ['liver', 'right kidney', 'spleen', 'pancreas', 'aorta', 'inferior vena cava', 'right adrenal gland', 'left adrenal gland', 'gallbladder', 'esophagus', 'stomach', 'duodenum', 'left kidney'],
180
- index=None,
181
- disabled=(not st.session_state.use_text_prompt)
182
- )
183
- else:
184
- pre_text = st.text_input('Enter an Anatomical word or phrase:', None, max_chars=20,
185
- disabled=(not st.session_state.use_text_prompt))
186
- if pre_text is None or len(pre_text) > 0:
187
- st.session_state.text_prompt = pre_text
188
- else:
189
- st.session_state.text_prompt = None
190
 
191
 
192
  with prompt_col2:
193
- spatial_prompt_on = st.toggle('Spatial prompt', on_change=clear_prompts)
194
- spatial_prompt = st.radio(
195
- "Spatial prompt type",
196
- ["Point prompt", "Box prompt"],
197
- on_change=clear_prompts,
198
- disabled=(not spatial_prompt_on))
199
- st.session_state.enforce_zoom = st.checkbox('Enforce zoom-out-zoom-in')
200
 
201
  if spatial_prompt == "Point prompt":
202
- st.session_state.use_point_prompt = True
203
- st.session_state.use_box_prompt = False
204
  elif spatial_prompt == "Box prompt":
205
- st.session_state.use_box_prompt = True
206
- st.session_state.use_point_prompt = False
207
  else:
208
- st.session_state.use_point_prompt = False
209
- st.session_state.use_box_prompt = False
210
 
211
  if not spatial_prompt_on:
212
- st.session_state.use_point_prompt = False
213
- st.session_state.use_box_prompt = False
214
 
215
  if not st.session_state.use_text_prompt:
216
- st.session_state.text_prompt = None
217
 
218
  if st.session_state.option is None:
219
- st.write('please select demo case first')
220
  else:
221
- image_3D = st.session_state.data_item['z_image'][0].numpy()
222
- col_control1, col_control2 = st.columns(2)
223
-
224
- with col_control1:
225
- selected_index_z = st.slider('X-Y view', 0, image_3D.shape[0] - 1, 162, key='xy', disabled=st.session_state.running)
226
-
227
- with col_control2:
228
- selected_index_y = st.slider('X-Z view', 0, image_3D.shape[1] - 1, 162, key='xz', disabled=st.session_state.running)
229
- if st.session_state.use_box_prompt:
230
- top, bottom = st.select_slider(
231
- 'Top and bottom of box',
232
- options=range(0, 325),
233
- value=(0, 324),
234
- disabled=st.session_state.running
235
- )
236
- st.session_state.rectangle_3Dbox[0] = top
237
- st.session_state.rectangle_3Dbox[3] = bottom
238
- col_image1, col_image2 = st.columns(2)
239
-
240
- if st.session_state.preds_3D is not None:
241
- st.session_state.transparency = st.slider('Mask opacity', 0.0, 1.0, 0.25, disabled=st.session_state.running)
242
-
243
- with col_image1:
244
-
245
- image_z_array = image_3D[selected_index_z]
246
-
247
- preds_z_array = None
248
- if st.session_state.preds_3D is not None:
249
- preds_z_array = st.session_state.preds_3D[selected_index_z]
250
-
251
- image_z = make_fig(image_z_array, preds_z_array, st.session_state.points, selected_index_z, 'xy')
252
-
253
-
254
- if st.session_state.use_point_prompt:
255
- value_xy = streamlit_image_coordinates(image_z, width=325)
256
-
257
- if value_xy is not None:
258
- point_ax_xy = (selected_index_z, value_xy['y'], value_xy['x'])
259
- if len(st.session_state.points) >= 3:
260
- st.warning('Max point num is 3', icon="⚠️")
261
- elif point_ax_xy not in st.session_state.points:
262
- st.session_state.points.append(point_ax_xy)
263
- print('point_ax_xy add rerun')
264
- st.rerun()
265
- elif st.session_state.use_box_prompt:
266
- canvas_result_xy = st_canvas(
267
- fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
268
- stroke_width=3,
269
- stroke_color='#2909F1',
270
- background_image=image_z,
271
- update_streamlit=True,
272
- height=325,
273
- width=325,
274
- drawing_mode='transform',
275
- point_display_radius=0,
276
- key="canvas_xy",
277
- initial_drawing=initial_rectangle,
278
- display_toolbar=True
279
- )
280
- try:
281
- print(canvas_result_xy.json_data['objects'][0]['angle'])
282
- if canvas_result_xy.json_data['objects'][0]['angle'] != 0:
283
- st.warning('Rotating is undefined behavior', icon="⚠️")
284
- st.session_state.irregular_box = True
285
- else:
286
- st.session_state.irregular_box = False
287
- reflect_json_data_to_3D_box(canvas_result_xy.json_data, view='xy')
288
- except:
289
- print('exception')
290
- pass
291
- else:
292
- st.image(image_z, use_column_width=False)
293
-
294
- with col_image2:
295
- image_y_array = image_3D[:, selected_index_y, :]
296
-
297
- preds_y_array = None
298
- if st.session_state.preds_3D is not None:
299
- preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]
300
-
301
- image_y = make_fig(image_y_array, preds_y_array, st.session_state.points, selected_index_y, 'xz')
302
-
303
- if st.session_state.use_point_prompt:
304
- value_yz = streamlit_image_coordinates(image_y, width=325)
305
-
306
- if value_yz is not None:
307
- point_ax_xz = (value_yz['y'], selected_index_y, value_yz['x'])
308
- if len(st.session_state.points) >= 3:
309
- st.warning('Max point num is 3', icon="⚠️")
310
- elif point_ax_xz not in st.session_state.points:
311
- st.session_state.points.append(point_ax_xz)
312
- print('point_ax_xz add rerun')
313
- st.rerun()
314
- elif st.session_state.use_box_prompt:
315
- if st.session_state.rectangle_3Dbox[1] <= selected_index_y and selected_index_y <= st.session_state.rectangle_3Dbox[4]:
316
- draw = ImageDraw.Draw(image_y)
317
- #rectangle xz view (upper-left and lower-right)
318
- rectangle_coords = [(st.session_state.rectangle_3Dbox[2], st.session_state.rectangle_3Dbox[0]),
319
- (st.session_state.rectangle_3Dbox[5], st.session_state.rectangle_3Dbox[3])]
320
- # Draw the rectangle on the image
321
- draw.rectangle(rectangle_coords, outline='#2909F1', width=3)
322
- st.image(image_y, use_column_width=False)
323
- else:
324
- st.image(image_y, use_column_width=False)
325
 
326
 
327
  col1, col2, col3 = st.columns(3)
328
 
329
  with col1:
330
- if st.button("Clear", use_container_width=True,
331
- disabled=(st.session_state.option is None or (len(st.session_state.points)==0 and not st.session_state.use_box_prompt and st.session_state.preds_3D is None))):
332
- clear_prompts()
333
- st.session_state.preds_3D = None
334
- st.session_state.preds_3D_ori = None
335
- st.rerun()
336
 
337
  with col2:
338
- img_nii = None
339
- if st.session_state.preds_3D_ori is not None and st.session_state.data_item is not None:
340
- meta_dict = st.session_state.data_item['meta']
341
- foreground_start_coord = st.session_state.data_item['foreground_start_coord']
342
- foreground_end_coord = st.session_state.data_item['foreground_end_coord']
343
- original_shape = st.session_state.data_item['ori_shape']
344
- pred_array = st.session_state.preds_3D_ori
345
- original_array = np.zeros(original_shape)
346
- original_array[foreground_start_coord[0]:foreground_end_coord[0],
347
- foreground_start_coord[1]:foreground_end_coord[1],
348
- foreground_start_coord[2]:foreground_end_coord[2]] = pred_array
349
-
350
- original_array = original_array.transpose(2, 1, 0)
351
- img_nii = nib.Nifti1Image(original_array, affine=meta_dict['affine'])
352
-
353
- with tempfile.NamedTemporaryFile(suffix=".nii.gz") as tmpfile:
354
- nib.save(img_nii, tmpfile.name)
355
- with open(tmpfile.name, "rb") as f:
356
- bytes_data = f.read()
357
- st.download_button(
358
- label="Download result(.nii.gz)",
359
- data=bytes_data,
360
- file_name="segvol_preds.nii.gz",
361
- mime="application/octet-stream",
362
- disabled=img_nii is None
363
- )
364
 
365
  with col3:
366
- run_button_name = 'Run'if not st.session_state.running else 'Running'
367
- if st.button(run_button_name, type="primary", use_container_width=True,
368
- disabled=(
369
- st.session_state.data_item is None or
370
- (st.session_state.text_prompt is None and len(st.session_state.points) == 0 and st.session_state.use_box_prompt is False) or
371
- st.session_state.irregular_box or
372
- st.session_state.running
373
- )):
374
- st.session_state.running = True
375
- st.rerun()
376
 
377
  if st.session_state.running:
378
- st.session_state.running = False
379
- with st.status("Running...", expanded=False) as status:
380
- run()
381
- st.rerun()
 
22
  #decorator singletion or use cache data class
23
  # https://docs.streamlit.io/develop/api-reference/caching-and-state/st.experimental_singleton
24
  # https://docs.streamlit.io/develop/concepts/architecture/caching
25
+ def download_idc_data_serieUID(serieUID_lst, output_folder):
26
+ #download IDC data cases
27
+ client = index.IDCClient()
28
+ #define serieUIDs to download
29
+ #download series and convert to .nii.gz
30
+ if os.path.exists(output_folder):
31
+ shutil.rmtree(output_folder)
32
+ os.makedirs(output_folder)
33
+ for idx, serieUID_ddl in enumerate(serieUID_lst):
34
+ sample_dcm_dir = os.path.join(output_folder, f"ddl_series{idx}_dcm")
35
+ sample_nii_dir = os.path.join(output_folder, f"ddl_series{idx}_nii")
36
+ for dir in [sample_dcm_dir, sample_nii_dir]:
37
+ if os.path.exists(dir):
38
+ shutil.rmtree(dir)
39
+ os.makedirs(dir)
40
+ client.download_from_selection(seriesInstanceUID=serieUID_ddl, downloadDir=sample_dcm_dir)
41
+ subprocess.call(["dcm2niix", "-o", sample_nii_dir, "-z", "y",
42
+ "-f", "IDC_%i", "-g", "y", sample_dcm_dir])
43
+ return glob.glob(os.path.join(output_folder, "*nii/*.nii.gz"))
 
 
 
 
 
 
 
44
 
45
  #############################################
46
  st.session_state.option = None
47
 
48
  if 'idc_data' not in st.session_state:
49
+ case_list = download_idc_data_serieUID(serieUID_lst=["1.3.6.1.4.1.14519.5.2.1.8421.4008.125612661111422710051062993644",
50
+ "1.3.6.1.4.1.14519.5.2.1.3344.4008.552105302448832783460360105045",
51
+ "1.3.6.1.4.1.14519.5.2.1.3344.4008.217290429362492484143666931850",
52
+ "1.3.6.1.4.1.14519.5.2.1.3344.4008.315023636447426194723399171147",
53
+ "1.3.6.1.4.1.14519.5.2.1.3344.4008.307374355712319704057189924161"],
54
+ output_folder="model/asset/idc_samples")
55
+ st.session_state.idc_data = True
56
  else:
57
+ case_list = glob.glob("model/asset/idc_samples/*nii/*.nii.gz")
58
+ if 'idc_serieUID_sample' not in st.session_state:
59
+ st.session_state.idc_serieUID_sample = None
60
  # init session_state
61
  if 'option' not in st.session_state:
62
+ st.session_state.option = None
63
  if 'text_prompt' not in st.session_state:
64
+ st.session_state.text_prompt = None
 
65
  if 'reset_demo_case' not in st.session_state:
66
+ st.session_state.reset_demo_case = False
67
 
68
  if 'preds_3D' not in st.session_state:
69
+ st.session_state.preds_3D = None
70
+ st.session_state.preds_3D_ori = None
71
 
72
  if 'data_item' not in st.session_state:
73
+ st.session_state.data_item = None
74
 
75
  if 'points' not in st.session_state:
76
+ st.session_state.points = []
77
 
78
  if 'use_text_prompt' not in st.session_state:
79
+ st.session_state.use_text_prompt = False
80
+
81
+ if 'use_text_serieUID' not in st.session_state:
82
+ st.session_state.use_text_serieUID = False
83
 
84
  if 'use_point_prompt' not in st.session_state:
85
+ st.session_state.use_point_prompt = False
86
 
87
  if 'use_box_prompt' not in st.session_state:
88
+ st.session_state.use_box_prompt = False
89
 
90
  if 'rectangle_3Dbox' not in st.session_state:
91
+ st.session_state.rectangle_3Dbox = [0,0,0,0,0,0]
92
 
93
  if 'irregular_box' not in st.session_state:
94
+ st.session_state.irregular_box = False
95
 
96
  if 'running' not in st.session_state:
97
+ st.session_state.running = False
98
 
99
  if 'transparency' not in st.session_state:
100
+ st.session_state.transparency = 0.25
 
 
 
 
 
 
 
 
101
  #############################################
102
 
103
  #############################################
104
  # reset functions
105
  def clear_prompts():
106
+ st.session_state.points = []
107
+ st.session_state.rectangle_3Dbox = [0,0,0,0,0,0]
108
 
109
  def reset_demo_case():
110
+ st.session_state.data_item = None
111
+ st.session_state.idc_serieUID_sample = None
112
+ st.session_state.reset_demo_case = True
113
+ clear_prompts()
114
 
115
  def clear_file():
116
+ st.session_state.option = None
117
+ st.session_state.idc_serieUID_sample = None
118
+ process_ct_gt.clear()
119
+ reset_demo_case()
120
+ clear_prompts()
121
 
122
  #############################################
123
 
 
126
  github_col, arxive_col = st.columns(2)
127
 
128
  with github_col:
129
+ st.write('GitHub repo:https://github.com/BAAI-DCAI/SegVol')
130
 
131
  with arxive_col:
132
+ st.write('Paper:https://arxiv.org/abs/2311.13385')
133
 
134
 
135
  # modify demo case here
136
  demo_type = st.radio(
137
+ "Demo case source",
138
+ ["Select an IDC demo case from tcga_lihc collection", "Filter by DICOM SerieUID"],
139
+ on_change=clear_file
140
+ )
141
+
142
+ if demo_type=="Select an IDC demo case from tcga_lihc collection":
143
+ uploaded_file = st.selectbox(
144
+ "Select a demo case",
145
+ case_list,
146
+ index=None,
147
+ placeholder="Select a demo case...",
148
+ on_change=reset_demo_case)
 
149
  else:
150
+ with st.form("Filter by DICOM SerieUID"):
151
+ uploaded_serieUID = st.text_input("Enter a DICOM SeriesInstanceUID", value=None)
152
+ submitted = st.form_submit_button("Submit", on_click=clear_prompts)
153
+ if submitted:
154
+ st.session_state.idc_serieUID_sample = download_idc_data_serieUID([str(uploaded_serieUID).strip()], "model/asset/idc_serieUID_sample")[0]
155
+ # st.session_state.option = uploaded_file
156
+ uploaded_file = st.session_state.idc_serieUID_sample
157
+ else:
158
+ uploaded_file = st.session_state.idc_serieUID_sample
159
+
160
+ st.session_state.option = uploaded_file
161
 
162
  if st.session_state.option is not None and \
163
+ st.session_state.reset_demo_case or (st.session_state.data_item is None and st.session_state.option is not None):
164
 
165
+ st.session_state.data_item = process_ct_gt(st.session_state.option)
166
+ st.session_state.reset_demo_case = False
167
+ st.session_state.preds_3D = None
168
+ st.session_state.preds_3D_ori = None
169
 
170
  prompt_col1, prompt_col2 = st.columns(2)
171
 
172
  with prompt_col1:
173
+ st.session_state.use_text_prompt = st.toggle('Sematic prompt')
174
+ text_prompt_type = st.radio(
175
+ "Sematic prompt type",
176
+ ["Predefined", "Custom"],
177
+ disabled=(not st.session_state.use_text_prompt)
178
+ )
179
+ if text_prompt_type == "Predefined":
180
+ pre_text = st.selectbox(
181
+ "Predefined anatomical category:",
182
+ ['liver', 'right kidney', 'spleen', 'pancreas', 'aorta', 'inferior vena cava', 'right adrenal gland', 'left adrenal gland', 'gallbladder', 'esophagus', 'stomach', 'duodenum', 'left kidney'],
183
+ index=None,
184
+ disabled=(not st.session_state.use_text_prompt)
185
+ )
186
+ else:
187
+ pre_text = st.text_input('Enter an Anatomical word or phrase:', None, max_chars=20,
188
+ disabled=(not st.session_state.use_text_prompt))
189
+ if pre_text is None or len(pre_text) > 0:
190
+ st.session_state.text_prompt = pre_text
191
+ else:
192
+ st.session_state.text_prompt = None
193
 
194
 
195
  with prompt_col2:
196
+ spatial_prompt_on = st.toggle('Spatial prompt', on_change=clear_prompts)
197
+ spatial_prompt = st.radio(
198
+ "Spatial prompt type",
199
+ ["Point prompt", "Box prompt"],
200
+ on_change=clear_prompts,
201
+ disabled=(not spatial_prompt_on))
202
+ st.session_state.enforce_zoom = st.checkbox('Enforce zoom-out-zoom-in')
203
 
204
  if spatial_prompt == "Point prompt":
205
+ st.session_state.use_point_prompt = True
206
+ st.session_state.use_box_prompt = False
207
  elif spatial_prompt == "Box prompt":
208
+ st.session_state.use_box_prompt = True
209
+ st.session_state.use_point_prompt = False
210
  else:
211
+ st.session_state.use_point_prompt = False
212
+ st.session_state.use_box_prompt = False
213
 
214
  if not spatial_prompt_on:
215
+ st.session_state.use_point_prompt = False
216
+ st.session_state.use_box_prompt = False
217
 
218
  if not st.session_state.use_text_prompt:
219
+ st.session_state.text_prompt = None
220
 
221
  if st.session_state.option is None:
222
+ st.write('please select demo case first')
223
  else:
224
+ image_3D = st.session_state.data_item['z_image'][0].numpy()
225
+ col_control1, col_control2 = st.columns(2)
226
+
227
+ with col_control1:
228
+ selected_index_z = st.slider('X-Y view', 0, image_3D.shape[0] - 1, 162, key='xy', disabled=st.session_state.running)
229
+
230
+ with col_control2:
231
+ selected_index_y = st.slider('X-Z view', 0, image_3D.shape[1] - 1, 162, key='xz', disabled=st.session_state.running)
232
+ if st.session_state.use_box_prompt:
233
+ top, bottom = st.select_slider(
234
+ 'Top and bottom of box',
235
+ options=range(0, 325),
236
+ value=(0, 324),
237
+ disabled=st.session_state.running
238
+ )
239
+ st.session_state.rectangle_3Dbox[0] = top
240
+ st.session_state.rectangle_3Dbox[3] = bottom
241
+ col_image1, col_image2 = st.columns(2)
242
+
243
+ if st.session_state.preds_3D is not None:
244
+ st.session_state.transparency = st.slider('Mask opacity', 0.0, 1.0, 0.25, disabled=st.session_state.running)
245
+
246
+ with col_image1:
247
+
248
+ image_z_array = image_3D[selected_index_z]
249
+
250
+ preds_z_array = None
251
+ if st.session_state.preds_3D is not None:
252
+ preds_z_array = st.session_state.preds_3D[selected_index_z]
253
+
254
+ image_z = make_fig(image_z_array, preds_z_array, st.session_state.points, selected_index_z, 'xy')
255
+
256
+
257
+ if st.session_state.use_point_prompt:
258
+ value_xy = streamlit_image_coordinates(image_z, width=325)
259
+
260
+ if value_xy is not None:
261
+ point_ax_xy = (selected_index_z, value_xy['y'], value_xy['x'])
262
+ if len(st.session_state.points) >= 3:
263
+ st.warning('Max point num is 3', icon="??")
264
+ elif point_ax_xy not in st.session_state.points:
265
+ st.session_state.points.append(point_ax_xy)
266
+ print('point_ax_xy add rerun')
267
+ st.rerun()
268
+ elif st.session_state.use_box_prompt:
269
+ canvas_result_xy = st_canvas(
270
+ fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
271
+ stroke_width=3,
272
+ stroke_color='#2909F1',
273
+ background_image=image_z,
274
+ update_streamlit=True,
275
+ height=325,
276
+ width=325,
277
+ drawing_mode='transform',
278
+ point_display_radius=0,
279
+ key="canvas_xy",
280
+ initial_drawing=initial_rectangle,
281
+ display_toolbar=True
282
+ )
283
+ try:
284
+ print(canvas_result_xy.json_data['objects'][0]['angle'])
285
+ if canvas_result_xy.json_data['objects'][0]['angle'] != 0:
286
+ st.warning('Rotating is undefined behavior', icon="??")
287
+ st.session_state.irregular_box = True
288
+ else:
289
+ st.session_state.irregular_box = False
290
+ reflect_json_data_to_3D_box(canvas_result_xy.json_data, view='xy')
291
+ except:
292
+ print('exception')
293
+ pass
294
+ else:
295
+ st.image(image_z, use_column_width=False)
296
+
297
+ with col_image2:
298
+ image_y_array = image_3D[:, selected_index_y, :]
299
+
300
+ preds_y_array = None
301
+ if st.session_state.preds_3D is not None:
302
+ preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]
303
+
304
+ image_y = make_fig(image_y_array, preds_y_array, st.session_state.points, selected_index_y, 'xz')
305
+
306
+ if st.session_state.use_point_prompt:
307
+ value_yz = streamlit_image_coordinates(image_y, width=325)
308
+
309
+ if value_yz is not None:
310
+ point_ax_xz = (value_yz['y'], selected_index_y, value_yz['x'])
311
+ if len(st.session_state.points) >= 3:
312
+ st.warning('Max point num is 3', icon="??")
313
+ elif point_ax_xz not in st.session_state.points:
314
+ st.session_state.points.append(point_ax_xz)
315
+ print('point_ax_xz add rerun')
316
+ st.rerun()
317
+ elif st.session_state.use_box_prompt:
318
+ if st.session_state.rectangle_3Dbox[1] <= selected_index_y and selected_index_y <= st.session_state.rectangle_3Dbox[4]:
319
+ draw = ImageDraw.Draw(image_y)
320
+ #rectangle xz view (upper-left and lower-right)
321
+ rectangle_coords = [(st.session_state.rectangle_3Dbox[2], st.session_state.rectangle_3Dbox[0]),
322
+ (st.session_state.rectangle_3Dbox[5], st.session_state.rectangle_3Dbox[3])]
323
+ # Draw the rectangle on the image
324
+ draw.rectangle(rectangle_coords, outline='#2909F1', width=3)
325
+ st.image(image_y, use_column_width=False)
326
+ else:
327
+ st.image(image_y, use_column_width=False)
328
 
329
 
330
  col1, col2, col3 = st.columns(3)
331
 
332
  with col1:
333
+ if st.button("Clear", use_container_width=True,
334
+ disabled=(st.session_state.option is None or (len(st.session_state.points)==0 and not st.session_state.use_box_prompt and st.session_state.preds_3D is None))):
335
+ clear_prompts()
336
+ st.session_state.preds_3D = None
337
+ st.session_state.preds_3D_ori = None
338
+ st.rerun()
339
 
340
  with col2:
341
+ img_nii = None
342
+ if st.session_state.preds_3D_ori is not None and st.session_state.data_item is not None:
343
+ meta_dict = st.session_state.data_item['meta']
344
+ foreground_start_coord = st.session_state.data_item['foreground_start_coord']
345
+ foreground_end_coord = st.session_state.data_item['foreground_end_coord']
346
+ original_shape = st.session_state.data_item['ori_shape']
347
+ pred_array = st.session_state.preds_3D_ori
348
+ original_array = np.zeros(original_shape)
349
+ original_array[foreground_start_coord[0]:foreground_end_coord[0],
350
+ foreground_start_coord[1]:foreground_end_coord[1],
351
+ foreground_start_coord[2]:foreground_end_coord[2]] = pred_array
352
+
353
+ original_array = original_array.transpose(2, 1, 0)
354
+ img_nii = nib.Nifti1Image(original_array, affine=meta_dict['affine'])
355
+
356
+ with tempfile.NamedTemporaryFile(suffix=".nii.gz") as tmpfile:
357
+ nib.save(img_nii, tmpfile.name)
358
+ with open(tmpfile.name, "rb") as f:
359
+ bytes_data = f.read()
360
+ st.download_button(
361
+ label="Download result(.nii.gz)",
362
+ data=bytes_data,
363
+ file_name="segvol_preds.nii.gz",
364
+ mime="application/octet-stream",
365
+ disabled=img_nii is None
366
+ )
367
 
368
  with col3:
369
+ run_button_name = 'Run'if not st.session_state.running else 'Running'
370
+ if st.button(run_button_name, type="primary", use_container_width=True,
371
+ disabled=(
372
+ st.session_state.data_item is None or
373
+ (st.session_state.text_prompt is None and len(st.session_state.points) == 0 and st.session_state.use_box_prompt is False) or
374
+ st.session_state.irregular_box or
375
+ st.session_state.running
376
+ )):
377
+ st.session_state.running = True
378
+ st.rerun()
379
 
380
  if st.session_state.running:
381
+ st.session_state.running = False
382
+ with st.status("Running...", expanded=False) as status:
383
+ run()
384
+ st.rerun()