rmm commited on
Commit
01fa6a9
·
1 Parent(s): 1311e0c

chore: reorganise code out of main

Browse files
src/input/input_handling.py CHANGED
@@ -339,6 +339,28 @@ def init_input_container_states() -> None:
339
  if "container_metadata_inputs" not in st.session_state:
340
  st.session_state.container_metadata_inputs = None
341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
 
343
  def add_input_UI_elements() -> None:
344
  '''
 
339
  if "container_metadata_inputs" not in st.session_state:
340
  st.session_state.container_metadata_inputs = None
341
 
342
+ def init_input_data_session_states() -> None:
343
+ '''
344
+ Initialise the session state variables used in the input handling
345
+ '''
346
+
347
+ if "image_hashes" not in st.session_state:
348
+ st.session_state.image_hashes = []
349
+
350
+ # TODO: ideally just use image_hashes, but need a unique key for the ui elements
351
+ # to track the user input phase; and these are created before the hash is generated.
352
+ if "image_filenames" not in st.session_state:
353
+ st.session_state.image_filenames = []
354
+
355
+ if "observations" not in st.session_state:
356
+ st.session_state.observations = {}
357
+
358
+ if "images" not in st.session_state:
359
+ st.session_state.images = {}
360
+
361
+ if "files" not in st.session_state:
362
+ st.session_state.files = {}
363
+
364
 
365
  def add_input_UI_elements() -> None:
366
  '''
src/main.py CHANGED
@@ -17,12 +17,13 @@ disable_caching()
17
  import whale_gallery as gallery
18
  import whale_viewer as viewer
19
  from input.input_handling import setup_input, check_inputs_are_set
20
- from input.input_handling import init_input_container_states, add_input_UI_elements
21
 
22
  from maps.alps_map import present_alps_map
23
  from maps.obs_map import present_obs_map
24
  from utils.st_logs import setup_logging, parse_log_buffer
25
  from utils.workflow_state import WorkflowFSM, FSM_STATES
 
26
  #from classifier.classifier_image import cetacean_classify
27
  from classifier.classifier_image import cetacean_just_classify, cetacean_show_results_and_review, cetacean_show_results
28
 
@@ -51,22 +52,6 @@ st.set_page_config(layout="wide")
51
  if "handler" not in st.session_state:
52
  st.session_state['handler'] = setup_logging()
53
 
54
- if "image_hashes" not in st.session_state:
55
- st.session_state.image_hashes = []
56
-
57
- # TODO: ideally just use image_hashes, but need a unique key for the ui elements
58
- # to track the user input phase; and these are created before the hash is generated.
59
- if "image_filenames" not in st.session_state:
60
- st.session_state.image_filenames = []
61
-
62
- if "observations" not in st.session_state:
63
- st.session_state.observations = {}
64
-
65
- if "images" not in st.session_state:
66
- st.session_state.images = {}
67
-
68
- if "files" not in st.session_state:
69
- st.session_state.files = {}
70
 
71
  if "public_observation" not in st.session_state:
72
  st.session_state.public_observation = {}
@@ -84,22 +69,11 @@ if "workflow_fsm" not in st.session_state:
84
  # create and init the state machine
85
  st.session_state.workflow_fsm = WorkflowFSM(FSM_STATES)
86
 
 
87
  init_input_container_states()
88
-
89
- def refresh_progress():
90
- with st.sidebar:
91
- tot = st.session_state.workflow_fsm.num_states - 1
92
- cur_i = st.session_state.workflow_fsm.current_state_index
93
- cur_t = st.session_state.workflow_fsm.current_state
94
- st.session_state.disp_progress[0].markdown(f"*Progress: {cur_i}/{tot}. Current: {cur_t}.*")
95
- st.session_state.disp_progress[1].progress(cur_i/tot)
96
- # add progress indicator to session_state
97
- if "progress" not in st.session_state:
98
- with st.sidebar:
99
- st.session_state.disp_progress = [st.empty(), st.empty()]
100
- # add button to sidebar, with the callback to refesh_progress
101
- st.sidebar.button("Refresh Progress", on_click=refresh_progress)
102
 
 
103
 
104
  def dbg_show_obs_hashes():
105
  # a debug: we seem to be losing the whale classes?
 
17
  import whale_gallery as gallery
18
  import whale_viewer as viewer
19
  from input.input_handling import setup_input, check_inputs_are_set
20
+ from input.input_handling import init_input_container_states, add_input_UI_elements, init_input_data_session_states
21
 
22
  from maps.alps_map import present_alps_map
23
  from maps.obs_map import present_obs_map
24
  from utils.st_logs import setup_logging, parse_log_buffer
25
  from utils.workflow_state import WorkflowFSM, FSM_STATES
26
+ from utils.workflow_ui import refresh_progress, init_workflow_viz
27
  #from classifier.classifier_image import cetacean_classify
28
  from classifier.classifier_image import cetacean_just_classify, cetacean_show_results_and_review, cetacean_show_results
29
 
 
52
  if "handler" not in st.session_state:
53
  st.session_state['handler'] = setup_logging()
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  if "public_observation" not in st.session_state:
57
  st.session_state.public_observation = {}
 
69
  # create and init the state machine
70
  st.session_state.workflow_fsm = WorkflowFSM(FSM_STATES)
71
 
72
+ init_input_data_session_states()
73
  init_input_container_states()
74
+ init_workflow_viz()
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+
77
 
78
  def dbg_show_obs_hashes():
79
  # a debug: we seem to be losing the whale classes?
src/utils/workflow_ui.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ def refresh_progress():
4
+ with st.sidebar:
5
+ tot = st.session_state.workflow_fsm.num_states - 1
6
+ cur_i = st.session_state.workflow_fsm.current_state_index
7
+ cur_t = st.session_state.workflow_fsm.current_state
8
+ st.session_state.disp_progress[0].markdown(f"*Progress: {cur_i}/{tot}. Current: {cur_t}.*")
9
+ st.session_state.disp_progress[1].progress(cur_i/tot)
10
+
11
+ def init_workflow_viz():
12
+ # add progress indicator to session_state
13
+ if "progress" not in st.session_state:
14
+ with st.sidebar:
15
+ st.session_state.disp_progress = [st.empty(), st.empty()]
16
+ # add button to sidebar, with the callback to refesh_progress
17
+ st.sidebar.button("Refresh Progress", on_click=refresh_progress)
18
+