cciausu97 commited on
Commit
fd4bfe4
·
verified ·
1 Parent(s): 82b653b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -43
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import streamlit as st
2
  from streamlit_drawable_canvas import st_canvas
3
  from streamlit_image_coordinates import streamlit_image_coordinates
@@ -18,9 +20,45 @@ import nibabel as nib
18
  import tempfile
19
 
20
  print('script run')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  #############################################
23
  st.session_state.option = None
 
 
 
 
 
 
24
  # init session_state
25
  if 'option' not in st.session_state:
26
  st.session_state.option = None
@@ -61,36 +99,11 @@ if 'running' not in st.session_state:
61
  if 'transparency' not in st.session_state:
62
  st.session_state.transparency = 0.25
63
 
64
- #download IDC data cases
65
- client = index.IDCClient()
66
- #define serieUIDs to download
67
- serieUIDs_list = ["1.3.6.1.4.1.14519.5.2.1.8421.4008.125612661111422710051062993644",
68
- "1.3.6.1.4.1.14519.5.2.1.3344.4008.552105302448832783460360105045",
69
- "1.3.6.1.4.1.14519.5.2.1.3344.4008.217290429362492484143666931850",
70
- "1.3.6.1.4.1.14519.5.2.1.3344.4008.315023636447426194723399171147",
71
- "1.3.6.1.4.1.14519.5.2.1.3344.4008.307374355712319704057189924161"]
72
-
73
- #download series and convert to .nii.gz
74
- samples_root_dir = "model/asset/idc_samples"
75
- if os.path.exists(samples_root_dir): shutil.rmtree(samples_root_dir)
76
- os.mkdir(samples_root_dir)
77
- for idx, serieUID_ddl in enumerate(serieUIDs_list):
78
- sample_dcm_dir = os.path.join(samples_root_dir, f"ddl_series{idx}_dcm")
79
- sample_nii_dir = os.path.join(samples_root_dir, f"ddl_series{idx}_nii")
80
- for dir in [sample_dcm_dir, sample_nii_dir]:
81
- if os.path.exists(dir): shutil.rmtree(dir)
82
- os.mkdir(dir)
83
- client.download_from_selection(seriesInstanceUID=\
84
- serieUID_ddl, \
85
- downloadDir=sample_dcm_dir)
86
- subprocess.call(["dcm2niix", "-o", sample_nii_dir, "-z", "y",
87
- "-f", "IDC_%i", "-g", "y", sample_dcm_dir])
88
- case_list = glob.glob(os.path.join(samples_root_dir, "*nii/*.nii.gz"))
89
  # case_list = [
90
- # 'model/asset/FLARE22_Tr_0002_0000.nii.gz',
91
- # 'model/asset/FLARE22_Tr_0005_0000.nii.gz',
92
- # 'model/asset/FLARE22_Tr_0034_0000.nii.gz',
93
- # 'model/asset/FLARE22_Tr_0045_0000.nii.gz'
94
  # ]
95
 
96
  #############################################
@@ -219,7 +232,7 @@ else:
219
  top, bottom = st.select_slider(
220
  'Top and bottom of box',
221
  options=range(0, 325),
222
- value=(0, 324),
223
  disabled=st.session_state.running
224
  )
225
  st.session_state.rectangle_3Dbox[0] = top
@@ -230,19 +243,19 @@ else:
230
  st.session_state.transparency = st.slider('Mask opacity', 0.0, 1.0, 0.25, disabled=st.session_state.running)
231
 
232
  with col_image1:
233
-
234
  image_z_array = image_3D[selected_index_z]
235
 
236
  preds_z_array = None
237
  if st.session_state.preds_3D is not None:
238
  preds_z_array = st.session_state.preds_3D[selected_index_z]
239
-
240
  image_z = make_fig(image_z_array, preds_z_array, st.session_state.points, selected_index_z, 'xy')
241
-
242
-
243
  if st.session_state.use_point_prompt:
244
  value_xy = streamlit_image_coordinates(image_z, width=325)
245
-
246
  if value_xy is not None:
247
  point_ax_xy = (selected_index_z, value_xy['y'], value_xy['x'])
248
  if len(st.session_state.points) >= 3:
@@ -282,16 +295,16 @@ else:
282
 
283
  with col_image2:
284
  image_y_array = image_3D[:, selected_index_y, :]
285
-
286
  preds_y_array = None
287
  if st.session_state.preds_3D is not None:
288
  preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]
289
-
290
  image_y = make_fig(image_y_array, preds_y_array, st.session_state.points, selected_index_y, 'xz')
291
-
292
  if st.session_state.use_point_prompt:
293
  value_yz = streamlit_image_coordinates(image_y, width=325)
294
-
295
  if value_yz is not None:
296
  point_ax_xz = (value_yz['y'], selected_index_y, value_yz['x'])
297
  if len(st.session_state.points) >= 3:
@@ -332,8 +345,8 @@ with col2:
332
  original_shape = st.session_state.data_item['ori_shape']
333
  pred_array = st.session_state.preds_3D_ori
334
  original_array = np.zeros(original_shape)
335
- original_array[foreground_start_coord[0]:foreground_end_coord[0],
336
- foreground_start_coord[1]:foreground_end_coord[1],
337
  foreground_start_coord[2]:foreground_end_coord[2]] = pred_array
338
 
339
  original_array = original_array.transpose(2, 1, 0)
@@ -356,8 +369,8 @@ with col3:
356
  if st.button(run_button_name, type="primary", use_container_width=True,
357
  disabled=(
358
  st.session_state.data_item is None or
359
- (st.session_state.text_prompt is None and len(st.session_state.points) == 0 and st.session_state.use_box_prompt is False) or
360
- st.session_state.irregular_box or
361
  st.session_state.running
362
  )):
363
  st.session_state.running = True
 
1
+ %%writefile app.py
2
+
3
  import streamlit as st
4
  from streamlit_drawable_canvas import st_canvas
5
  from streamlit_image_coordinates import streamlit_image_coordinates
 
20
  import tempfile
21
 
22
  print('script run')
23
+ #further improvement
24
+ #decorator singletion or use cache data class
25
+ # https://docs.streamlit.io/develop/api-reference/caching-and-state/st.experimental_singleton
26
+ # https://docs.streamlit.io/develop/concepts/architecture/caching
27
+ def download_idc_data():
28
+ #download IDC data cases
29
+ client = index.IDCClient()
30
+ #define serieUIDs to download
31
+ serieUIDs_list = ["1.3.6.1.4.1.14519.5.2.1.8421.4008.125612661111422710051062993644",
32
+ "1.3.6.1.4.1.14519.5.2.1.3344.4008.552105302448832783460360105045",
33
+ "1.3.6.1.4.1.14519.5.2.1.3344.4008.217290429362492484143666931850",
34
+ "1.3.6.1.4.1.14519.5.2.1.3344.4008.315023636447426194723399171147",
35
+ "1.3.6.1.4.1.14519.5.2.1.3344.4008.307374355712319704057189924161"]
36
+
37
+ #download series and convert to .nii.gz
38
+ samples_root_dir = "/content/SegVolIDCTest/model/asset/idc_samples"
39
+ if not os.path.exists(samples_root_dir) : os.mkdir(samples_root_dir)
40
+ for idx, serieUID_ddl in enumerate(serieUIDs_list):
41
+ sample_dcm_dir = os.path.join(samples_root_dir, f"ddl_series{idx}_dcm")
42
+ sample_nii_dir = os.path.join(samples_root_dir, f"ddl_series{idx}_nii")
43
+ for dir in [sample_dcm_dir, sample_nii_dir]:
44
+ if os.path.exists(dir):
45
+ shutil.rmtree(dir)
46
+ os.makedirs(dir)
47
+ client.download_from_selection(seriesInstanceUID=\
48
+ serieUID_ddl, \
49
+ downloadDir=sample_dcm_dir)
50
+ subprocess.call(["dcm2niix", "-o", sample_nii_dir, "-z", "y",
51
+ "-f", "IDC_%i", "-g", "y", sample_dcm_dir])
52
+ return glob.glob(os.path.join(samples_root_dir, "*nii/*.nii.gz"))
53
 
54
  #############################################
55
  st.session_state.option = None
56
+
57
+ if 'idc_data' not in st.session_state:
58
+ case_list = download_idc_data()
59
+ st.session_state.idc_data = True
60
+ else:
61
+ case_list = glob.glob("/content/SegVolIDCTest/model/asset/idc_samples/*nii/*.nii.gz")
62
  # init session_state
63
  if 'option' not in st.session_state:
64
  st.session_state.option = None
 
99
  if 'transparency' not in st.session_state:
100
  st.session_state.transparency = 0.25
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  # case_list = [
103
+ # '/content/SegVolIDCTest/model/asset/FLARE22_Tr_0002_0000.nii.gz',
104
+ # '/content/SegVolIDCTest/model/asset/FLARE22_Tr_0005_0000.nii.gz',
105
+ # '/content/SegVolIDCTest/model/asset/FLARE22_Tr_0034_0000.nii.gz',
106
+ # '/content/SegVolIDCTest/model/asset/FLARE22_Tr_0045_0000.nii.gz'
107
  # ]
108
 
109
  #############################################
 
232
  top, bottom = st.select_slider(
233
  'Top and bottom of box',
234
  options=range(0, 325),
235
+ value=(0, 324),
236
  disabled=st.session_state.running
237
  )
238
  st.session_state.rectangle_3Dbox[0] = top
 
243
  st.session_state.transparency = st.slider('Mask opacity', 0.0, 1.0, 0.25, disabled=st.session_state.running)
244
 
245
  with col_image1:
246
+
247
  image_z_array = image_3D[selected_index_z]
248
 
249
  preds_z_array = None
250
  if st.session_state.preds_3D is not None:
251
  preds_z_array = st.session_state.preds_3D[selected_index_z]
252
+
253
  image_z = make_fig(image_z_array, preds_z_array, st.session_state.points, selected_index_z, 'xy')
254
+
255
+
256
  if st.session_state.use_point_prompt:
257
  value_xy = streamlit_image_coordinates(image_z, width=325)
258
+
259
  if value_xy is not None:
260
  point_ax_xy = (selected_index_z, value_xy['y'], value_xy['x'])
261
  if len(st.session_state.points) >= 3:
 
295
 
296
  with col_image2:
297
  image_y_array = image_3D[:, selected_index_y, :]
298
+
299
  preds_y_array = None
300
  if st.session_state.preds_3D is not None:
301
  preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]
302
+
303
  image_y = make_fig(image_y_array, preds_y_array, st.session_state.points, selected_index_y, 'xz')
304
+
305
  if st.session_state.use_point_prompt:
306
  value_yz = streamlit_image_coordinates(image_y, width=325)
307
+
308
  if value_yz is not None:
309
  point_ax_xz = (value_yz['y'], selected_index_y, value_yz['x'])
310
  if len(st.session_state.points) >= 3:
 
345
  original_shape = st.session_state.data_item['ori_shape']
346
  pred_array = st.session_state.preds_3D_ori
347
  original_array = np.zeros(original_shape)
348
+ original_array[foreground_start_coord[0]:foreground_end_coord[0],
349
+ foreground_start_coord[1]:foreground_end_coord[1],
350
  foreground_start_coord[2]:foreground_end_coord[2]] = pred_array
351
 
352
  original_array = original_array.transpose(2, 1, 0)
 
369
  if st.button(run_button_name, type="primary", use_container_width=True,
370
  disabled=(
371
  st.session_state.data_item is None or
372
+ (st.session_state.text_prompt is None and len(st.session_state.points) == 0 and st.session_state.use_box_prompt is False) or
373
+ st.session_state.irregular_box or
374
  st.session_state.running
375
  )):
376
  st.session_state.running = True