Commit 00bdefd (1 parent: b384db4), committed by rmm

feat: implementation of FSM, and invocation for first phases

- fsm implementation uses the `transitions` package.
- added unique keys to the input forms, so we can check when all are filled
- included a basic viz/feedback on the state
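
For readers who have not used the `transitions` package (added to the requirements below), here is a minimal, self-contained sketch of the pattern this commit builds on: a model object pushed through an ordered list of states, with one trigger per hop. The `Phases` class and the shortened state list are illustrative only; the actual implementation is `WorkflowFSM` in `src/utils/workflow_state.py` further down.

```python
from transitions import Machine

class Phases:
    """Illustrative model object; Machine attaches the trigger methods to it."""
    pass

states = ["init", "data_entry_complete", "data_entry_validated"]  # shortened for the sketch
phases = Phases()
machine = Machine(model=phases, states=states, initial=states[0])

# chain each state to the next, mirroring the linear workflow used in this commit
for src, dest in zip(states, states[1:]):
    machine.add_transition(trigger=f"complete_{src}", source=src, dest=dest)

phases.complete_init()                 # init -> data_entry_complete
phases.complete_data_entry_complete()  # data_entry_complete -> data_entry_validated
print(phases.state)                    # data_entry_validated
```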

requirements.txt CHANGED
@@ -10,7 +10,8 @@ streamlit_folium==0.23.1
 
 # backend
 datasets==3.0.2
-
+## FSM
+transitions==0.9.2
 
 # running ML models
 
src/input/input_handling.py CHANGED
@@ -30,6 +30,43 @@ spoof_metadata = {
     "time": None,
 }
 
+def check_inputs_are_set(empty_ok:bool=False, debug:bool=False) -> bool:
+    """
+    Checks if all expected inputs have been entered
+
+    Implementation: via the Streamlit session state.
+
+    Args:
+        empty_ok (bool): If True, returns True if no inputs are set. Default is False.
+        debug (bool): If True, prints and logs the status of each expected input key. Default is False.
+    Returns:
+        bool: True if all expected input keys are set, False otherwise.
+    """
+    filenames = st.session_state.image_filenames
+    if len(filenames) == 0:
+        return empty_ok
+
+
+
+    exp_input_key_stubs = ["input_latitude", "input_longitude"]
+    #exp_input_key_stubs = ["input_latitude", "input_longitude", "input_author_email", "input_date", "input_time", "input_image_selector"]
+    vals = []
+    for image_filename in filenames:
+        for stub in exp_input_key_stubs:
+            key = f"{stub}_{image_filename}"
+            val = None
+            if key in st.session_state:
+                val = st.session_state[key]
+            vals.append(val)
+            if debug:
+                msg = f"{key:15}, {(val is not None):8}, {val}"
+                m_logger.debug(msg)
+                print(msg)
+
+    return all([v is not None for v in vals])
+
+
+
 def setup_input(
         viewcontainer: DeltaGenerator=None,
         _allowed_image_types: list=None, ) -> InputObservation:
@@ -66,7 +103,8 @@ def setup_input(
     uploaded_files = viewcontainer.file_uploader("Upload an image", type=allowed_image_types, accept_multiple_files=True)
     observations = {}
     images = {}
-    image_hashes =[]
+    image_hashes = []
+    filenames = []
     if uploaded_files is not None:
         for file in uploaded_files:
 
@@ -76,6 +114,7 @@ def setup_input(
             # load image using cv2 format, so it is compatible with the ML models
             file_bytes = np.asarray(bytearray(file.read()), dtype=np.uint8)
             filename = file.name
+            filenames.append(filename)
             image = cv2.imdecode(file_bytes, 1)
             # Extract and display image date-time
             image_datetime = None # For storing date-time from image
@@ -84,12 +123,18 @@ def setup_input(
 
 
             # 3. Latitude Entry Box
-            latitude = viewcontainer.text_input("Latitude for "+filename, spoof_metadata.get('latitude', ""))
+            latitude = viewcontainer.text_input(
+                "Latitude for "+filename,
+                spoof_metadata.get('latitude', ""),
+                key=f"input_latitude_{filename}")
             if latitude and not is_valid_number(latitude):
                 viewcontainer.error("Please enter a valid latitude (numerical only).")
                 m_logger.error(f"Invalid latitude entered: {latitude}.")
             # 4. Longitude Entry Box
-            longitude = viewcontainer.text_input("Longitude for "+filename, spoof_metadata.get('longitude', ""))
+            longitude = viewcontainer.text_input(
+                "Longitude for "+filename,
+                spoof_metadata.get('longitude', ""),
+                key=f"input_longitude_{filename}")
            if longitude and not is_valid_number(longitude):
                 viewcontainer.error("Please enter a valid longitude (numerical only).")
                 m_logger.error(f"Invalid latitude entered: {latitude}.")
@@ -118,4 +163,6 @@ def setup_input(
     st.session_state.files = uploaded_files
     st.session_state.observations = observations
     st.session_state.image_hashes = image_hashes
+    st.session_state.image_filenames = filenames
+
 
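The per-file widget keys introduced above are what make `check_inputs_are_set` work: every keyed widget publishes its value into `st.session_state`, so the completeness check is just a lookup over the `f"{stub}_{filename}"` keys. A minimal sketch of that mechanism outside the app (the filename is a placeholder, not taken from an uploader):

```python
import streamlit as st

filename = "example.jpg"  # placeholder; in the app this comes from the uploaded file

# each widget gets a per-file key, so its value is reachable via st.session_state
st.text_input(f"Latitude for {filename}", key=f"input_latitude_{filename}")
st.text_input(f"Longitude for {filename}", key=f"input_longitude_{filename}")

# the same lookup check_inputs_are_set() performs for each (stub, filename) pair
vals = [st.session_state.get(f"{stub}_{filename}")
        for stub in ("input_latitude", "input_longitude")]
st.write("all inputs present:", all(v is not None for v in vals))
```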
src/main.py CHANGED
@@ -15,10 +15,11 @@ disable_caching()
 
 import whale_gallery as gallery
 import whale_viewer as viewer
-from input.input_handling import setup_input
+from input.input_handling import setup_input, check_inputs_are_set
 from maps.alps_map import present_alps_map
 from maps.obs_map import present_obs_map
 from utils.st_logs import setup_logging, parse_log_buffer
+from utils.workflow_state import WorkflowFSM, FSM_STATES
 from classifier.classifier_image import cetacean_classify
 from classifier.classifier_hotdog import hotdog_classify
 
@@ -48,6 +49,11 @@ if "handler" not in st.session_state:
 if "image_hashes" not in st.session_state:
     st.session_state.image_hashes = []
 
+# TODO: ideally just use image_hashes, but need a unique key for the ui elements
+# to track the user input phase; and these are created before the hash is generated.
+if "image_filenames" not in st.session_state:
+    st.session_state.image_filenames = []
+
 if "observations" not in st.session_state:
     st.session_state.observations = {}
 
@@ -69,6 +75,23 @@ if "whale_prediction1" not in st.session_state:
 if "tab_log" not in st.session_state:
     st.session_state.tab_log = None
 
+if "workflow_fsm" not in st.session_state:
+    # create and init the state machine
+    st.session_state.workflow_fsm = WorkflowFSM(FSM_STATES)
+
+# add progress indicator to session_state
+if "progress" not in st.session_state:
+    with st.sidebar:
+        st.session_state.disp_progress = [st.empty(), st.empty()]
+
+def refresh_progress():
+    with st.sidebar:
+        tot = st.session_state.workflow_fsm.num_states
+        cur_i = st.session_state.workflow_fsm.current_state_index
+        cur_t = st.session_state.workflow_fsm.current_state
+        st.session_state.disp_progress[0].markdown(f"*Progress: {cur_i}/{tot}. Current: {cur_t}.*")
+        st.session_state.disp_progress[1].progress(cur_i/tot)
+
 
 def main() -> None:
     """
@@ -102,6 +125,10 @@ def main() -> None:
         st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "*:gray[Dev:coordinates]*", "Log", "Beautiful cetaceans"])
     st.session_state.tab_log = tab_log
 
+    refresh_progress()
+    # add button to sidebar, with the callback to refesh_progress
+    st.sidebar.button("Refresh Progress", on_click=refresh_progress)
+
 
     # create a sidebar, and parse all the input (returned as `observations` object)
     setup_input(viewcontainer=st.sidebar)
@@ -181,14 +208,25 @@ def main() -> None:
 
 
     # Display submitted observation
-    if st.sidebar.button("Validate"):
-        # create a dictionary with the submitted observation
-        tab_log.info(f"{st.session_state.observations}")
-        df = pd.DataFrame(st.session_state.observations, index=[0])
-        with tab_coords:
-            st.table(df)
+    all_inputs_set = check_inputs_are_set(debug=True)
+    if not all_inputs_set:
+        st.sidebar.button(":gray[*Validate*]", disabled=True, help="Please fill in all fields.")
+
+    else:
+        if st.session_state.workflow_fsm.is_in_state('init'):
+            st.session_state.workflow_fsm.complete_current_state()
 
+        if st.sidebar.button("**Validate**"):
+            if st.session_state.workflow_fsm.is_in_state('data_entry_complete'):
+                st.session_state.workflow_fsm.complete_current_state()
 
+            # create a dictionary with the submitted observation
+            tab_log.info(f"{st.session_state.observations}")
+            df = pd.DataFrame(st.session_state.observations, index=[0])
+            with tab_coords:
+                st.table(df)
+
+
 
 
     # inside the inference tab, on button press we call the model (on huggingface hub)
@@ -240,6 +278,9 @@ def main() -> None:
         hotdog_classify(pipeline_hot_dog, tab_hotdogs)
 
 
+    # after all other processing, we can show the stage/state
+    refresh_progress()
+
 
 if __name__ == "__main__":
     main()
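
Stripped of the Streamlit widgets, the validation flow added here advances the FSM in two steps: the workflow reaches `data_entry_complete` as soon as all inputs are filled, and `data_entry_validated` once the Validate button is pressed. A framework-free sketch of that control flow (`advance_data_entry`, `inputs_set`, and `validate_clicked` are illustrative names, not part of the commit; the import assumes the script runs from `src/`, as `main.py` does):

```python
from utils.workflow_state import WorkflowFSM, FSM_STATES

def advance_data_entry(fsm: WorkflowFSM, inputs_set: bool, validate_clicked: bool) -> str:
    """Mirror of the gating logic in main(): guard on the current state, then advance."""
    if inputs_set and fsm.is_in_state('init'):
        fsm.complete_current_state()          # init -> data_entry_complete
    if validate_clicked and fsm.is_in_state('data_entry_complete'):
        fsm.complete_current_state()          # data_entry_complete -> data_entry_validated
    return fsm.current_state

fsm = WorkflowFSM(FSM_STATES)
print(advance_data_entry(fsm, inputs_set=True, validate_clicked=False))  # data_entry_complete
print(advance_data_entry(fsm, inputs_set=True, validate_clicked=True))   # data_entry_validated
```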
src/utils/workflow_state.py ADDED
@@ -0,0 +1,92 @@
+from transitions import Machine
+from typing import List
+
+OKBLUE = '\033[94m'
+OKGREEN = '\033[92m'
+OKCYAN = '\033[96m'
+FAIL = '\033[91m'
+ENDC = '\033[0m'
+
+
+FSM_STATES = ['init', 'data_entry_complete', 'data_entry_validated', 'ml_classification_started', 'ml_classification_completed', 'manual_inspection_completed', 'data_uploaded']
+
+
+class WorkflowFSM:
+    def __init__(self, state_sequence: List[str]):
+        self.state_sequence = state_sequence
+        self.state_dict = {state: i for i, state in enumerate(state_sequence)}
+
+        # Create state machine
+        self.machine = Machine(
+            model=self,
+            states=state_sequence,
+            initial=state_sequence[0],
+        )
+
+        # For each state (except the last), add a completion transition to the next state
+        for i in range(len(state_sequence) - 1):
+            current_state = state_sequence[i]
+            next_state = state_sequence[i + 1]
+
+            self.machine.add_transition(
+                trigger=f'complete_{current_state}',
+                source=current_state,
+                dest=next_state,
+                conditions=[f'is_in_{current_state}']
+            )
+
+            # Dynamically add a condition method for each state
+            setattr(self, f'is_in_{current_state}',
+                    lambda s=current_state: self.is_in_state(s))
+
+        # Add callbacks for logging
+        self.machine.before_state_change = self._log_transition
+        self.machine.after_state_change = self._post_transition
+
+    def is_in_state(self, state_name: str) -> bool:
+        """Check if we're currently in the specified state"""
+        return self.state == state_name
+
+    def complete_current_state(self) -> bool:
+        """
+        Signal that the current state is complete.
+        Returns True if state transition occurred, False otherwise.
+        """
+        current_state = self.state
+        trigger_name = f'complete_{current_state}'
+
+        if hasattr(self, trigger_name):
+            try:
+                trigger_func = getattr(self, trigger_name)
+                trigger_func()
+                return True
+            except:
+                return False
+        return False
+
+    @property
+    def current_state(self) -> str:
+        """Get the current state name"""
+        return self.state
+
+    @property
+    def current_state_index(self) -> int:
+        """Get the current state index"""
+        return self.state_dict[self.state]
+
+    @property
+    def num_states(self) -> int:
+        return len(self.state_sequence)
+
+
+    def _log_transition(self):
+        # TODO: use logger, not printing.
+        self._cprint(f"[FSM] -> Transitioning from {self.current_state}")
+
+    def _post_transition(self):
+        # TODO: use logger, not printing.
+        self._cprint(f"[FSM] -| Transitioned to {self.current_state}")
+
+    def _cprint(self, msg:str, color:str=OKCYAN):
+        """Print colored message"""
+        print(f"{color}{msg}{ENDC}")
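
A quick smoke test of the new module, assuming it is run from `src/` with `transitions` installed: the FSM only ever advances one state per `complete_current_state()` call, and the final state has no onward transition.

```python
from utils.workflow_state import WorkflowFSM, FSM_STATES

fsm = WorkflowFSM(FSM_STATES)
print(fsm.current_state, f"{fsm.current_state_index}/{fsm.num_states}")  # init 0/7

# walk the whole workflow, one state per call
while fsm.current_state != FSM_STATES[-1]:
    assert fsm.complete_current_state()

print(fsm.current_state)             # data_uploaded
print(fsm.complete_current_state())  # False: no transition is defined out of the last state
```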