rmm committed on
Commit
80b4be6
·
1 Parent(s): f824145

feat: first implementation of an FSM to keep track of phase

Browse files

- classifier_image gets split into multiple functions: run inference;
manual review/validation (with display of results); just display of
results
- workflow_state implements a FSM using `transitions`. It is a bit too
simple in the end, as the idea was to have a trigger that allows
moving to the next state without needing to know what that state is
called (otherwise the specification of pathways by data structs
doesn't simplify it). But it is too buggy this way, you can advance
in places that you shouldn't. So will refactor this.
- in main
- we convert some more session_state to dicts, to handle image batches
- add a simple widget to show progress through the workflow
- the main effort to stop losing progress is in tab_inference:
lots of state testing to decide which action to take. You can see
several attempts that still need cleaning up, now that I understand
a big bug was gating everything on the inference button.

requirements.txt CHANGED
@@ -10,6 +10,10 @@ streamlit_folium==0.23.1
10
 
11
  # backend
12
  datasets==3.0.2
 
 
 
 
13
 
14
 
15
  # running ML models
 
10
 
11
  # backend
12
  datasets==3.0.2
13
+ # - FSM
14
+ transitions==0.9.2
15
+ # optional, dev for the FSM (diagrams)
16
+ # pyperclip==1.9.0
17
 
18
 
19
  # running ML models
src/classifier/classifier_image.py CHANGED
@@ -11,7 +11,111 @@ from hf_push_observations import push_observations
11
  from utils.grid_maker import gridder
12
  from utils.metadata_handler import metadata2md
13
 
14
- def cetacean_classify(cetacean_classifier):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  files = st.session_state.files
16
  images = st.session_state.images
17
  observations = st.session_state.observations
 
11
  from utils.grid_maker import gridder
12
  from utils.metadata_handler import metadata2md
13
 
14
+ # need to divide this into two functions, one for the classification and one for the display
15
+ # it is currently somewhat interleaved, not totally clear how to separate them.
16
+ # perhaps we have more stages than I realised.
17
+ # ML started, ML completed, manual review completed, data uploaded
18
+
19
+ # for now, let's implement the division between ML classification, and display+manual review.
20
+
21
def cetacean_classify_list(cetacean_classifier):
    """Run the cetacean classifier on every uploaded image and store the results.

    Reads `files`, `images` and `observations` from `st.session_state`. For each
    file (keyed by `file.name`) the classifier output is recorded in
    `st.session_state.whale_prediction1`, `st.session_state.classify_whale_done`,
    the observation object (via `set_top_predictions`) and
    `st.session_state.public_observation`. Nothing is rendered apart from the
    balloons shown at the end.

    Args:
        cetacean_classifier: callable applied to one image; its result is
            indexed with `['predictions']`.

    Returns:
        True once all files have been processed. There is no failure
        detection yet, so any error surfaces as an exception instead.
    """
    files = st.session_state.files
    images = st.session_state.images
    observations = st.session_state.observations

    for file in files:
        key = file.name

        # NOTE(review): this dict snapshot is taken before set_top_predictions()
        # runs below, so it may not contain the predictions - confirm against
        # the observation class's to_dict().
        observation = observations[key].to_dict()

        # run the model on this image and keep the output across reruns
        out = cetacean_classifier(images[key])  # get top 3 matches
        predictions = out['predictions']
        st.session_state.whale_prediction1[key] = predictions[0]
        st.session_state.classify_whale_done[key] = True  # TODO 25.01 unclear what this is for;
        g_logger.info(
            f"[D]2 classify_whale_done: {st.session_state.classify_whale_done[key]}, whale_prediction1: {st.session_state.whale_prediction1[key]}"
        )

        # hand the observation its own copy of the prediction list
        observations[key].set_top_predictions(predictions[:])

        st.session_state.public_observation[key] = observation
        msg = f"[D] full observation after inference: {observation}"
        g_logger.debug(msg)
        print(msg)

    # TODO: add some mech to test if it was successful.
    st.balloons()
    return True
55
+
56
def cetacean_show_classifications():
    """Render every uploaded image with its classification and review controls.

    Reads `files`, `images`, `observations`, `classify_whale_done` and
    `whale_prediction1` from `st.session_state`. For each file a grid cell shows
    the image, a species dropdown (pre-selected with the top prediction, or
    disabled when the image has not been classified yet), an upload button, the
    metadata summary and the observation's top predictions.

    The per-image lookups use `.get()` so that an image without an inference
    result falls into the "not yet identified" branch instead of raising
    `KeyError`.

    Returns:
        True, unconditionally, after all files have been rendered.
    """
    st.write("TOP TEXT")
    st.write("Reviewing the classifications :construction:")
    files = st.session_state.files
    images = st.session_state.images
    observations = st.session_state.observations

    # only the row width is used here
    _batch_size, row_size, _page = gridder(files)

    grid = st.columns(row_size)
    col = 0

    for file in files:
        key = file.name
        image = images[key]

        # An image that has not been through inference has no entry in either
        # dict, so look the values up defensively.
        done = st.session_state.classify_whale_done.get(key, False)
        pred1 = st.session_state.whale_prediction1.get(key)

        with grid[col]:
            st.image(image, use_column_width=True)
            observation = observations[key].to_dict()
            # fetch the classification results
            msg = f"[D]2b classify_whale_done ({file}): {done}, whale_prediction1: {pred1}"
            g_logger.info(msg)

            # dropdown for selecting/overriding the species prediction
            # TODO: the "it's done" flag seems to get reset when we re-load the tab. Not quite right.
            if not done:
                # TODO: ask LV why it is in the sidebar, and not in the grid
                selected_class = st.selectbox(
                    "Species", viewer.WHALE_CLASSES,
                    index=None, placeholder="Species not yet identified...",
                    disabled=True, key=f"cldd_{key}")
            else:
                # index of the top prediction within WHALE_CLASSES, None if absent
                print(f"[D] pred1: {pred1}")
                ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
                selected_class = st.selectbox(f"Species for {file.name}", viewer.WHALE_CLASSES, index=ix)

            observation['predicted_class'] = selected_class
            if selected_class != pred1:
                # the user picked something other than the model's top prediction
                observation['class_overriden'] = selected_class

            # NOTE(review): this replaces the whole public_observation value with
            # a single observation, while cetacean_classify_list stores one entry
            # per key - confirm which shape push_observations expects.
            st.session_state.public_observation = observation
            st.button(f"Upload observation for {file.name} to THE INTERNET!", on_click=push_observations)
            # TODO: the metadata only fills properly if `validate` was clicked.
            st.markdown(metadata2md())

            msg = f"[D] full observation after inference: {observation}"
            g_logger.debug(msg)
            print(msg)
            # TODO: add a link to more info on the model, next to the button.
            whale_classes = observations[key].top_predictions
            # render images for the top 3 (that is what the model api returns)
            n = len(whale_classes)
            st.markdown(f"Top {n} Predictions for {file.name}")
            for i in range(n):
                viewer.display_whale(whale_classes, i)
        # move to the next grid column, wrapping at the row width
        col = (col + 1) % row_size
    return True
116
+
117
+
118
+ def cetacean_classify_and_review(cetacean_classifier):
119
  files = st.session_state.files
120
  images = st.session_state.images
121
  observations = st.session_state.observations
src/main.py CHANGED
@@ -17,8 +17,11 @@ import whale_viewer as viewer
17
  from input.input_handling import setup_input
18
  from maps.alps_map import present_alps_map
19
  from maps.obs_map import present_obs_map
 
 
20
  from utils.st_logs import setup_logging, parse_log_buffer
21
- from classifier.classifier_image import cetacean_classify
 
22
  from classifier.classifier_hotdog import hotdog_classify
23
 
24
 
@@ -57,14 +60,31 @@ if "public_observation" not in st.session_state:
57
  st.session_state.public_observation = {}
58
 
59
  if "classify_whale_done" not in st.session_state:
60
- st.session_state.classify_whale_done = False
61
 
62
  if "whale_prediction1" not in st.session_state:
63
- st.session_state.whale_prediction1 = None
64
 
65
  if "tab_log" not in st.session_state:
66
  st.session_state.tab_log = None
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  def main() -> None:
70
  """
@@ -98,11 +118,15 @@ def main() -> None:
98
  st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "*:gray[Dev:coordinates]*", "Log", "Beautiful cetaceans"])
99
  st.session_state.tab_log = tab_log
100
 
 
 
 
 
101
 
102
  # create a sidebar, and parse all the input (returned as `observations` object)
103
  observations = setup_input(viewcontainer=st.sidebar)
104
 
105
-
106
  if 0:## WIP
107
  # goal of this code is to allow the user to override the ML prediction, before transmitting an observations
108
  predicted_class = st.sidebar.selectbox("Predicted Class", viewer.WHALE_CLASSES)
@@ -118,7 +142,7 @@ def main() -> None:
118
  with tab_map:
119
  # visual structure: a couple of toggles at the top, then the map inlcuding a
120
  # dropdown for tileset selection.
121
- sw_map.add_header_text()
122
  tab_map_ui_cols = st.columns(2)
123
  with tab_map_ui_cols[0]:
124
  show_db_points = st.toggle("Show Points from DB", True)
@@ -178,9 +202,14 @@ def main() -> None:
178
 
179
  # Display submitted observation
180
  if st.sidebar.button("Validate"):
 
 
181
  # create a dictionary with the submitted observation
182
  submitted_data = observations
183
  st.session_state.observations = observations
 
 
 
184
 
185
  tab_log.info(f"{st.session_state.observations}")
186
 
@@ -202,20 +231,74 @@ def main() -> None:
202
  Once inference is complete, the top three predictions are shown.
203
  You can override the prediction by selecting a species from the dropdown.*""")
204
 
205
- if tab_inference.button("Identify with cetacean classifier"):
206
- #pipe = pipeline("image-classification", model="Saving-Willy/cetacean-classifier", trust_remote_code=True)
207
- cetacean_classifier = AutoModelForImageClassification.from_pretrained("Saving-Willy/cetacean-classifier",
208
- revision=classifier_revision,
209
- trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
 
212
- if st.session_state.images is None:
213
- # TODO: cleaner design to disable the button until data input done?
214
- st.info("Please upload an image first.")
215
- else:
216
- cetacean_classify(cetacean_classifier)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
-
219
 
220
 
221
  # inside the hotdog tab, on button press we call a 2nd model (totally unrelated at present, just for demo
@@ -240,6 +323,9 @@ def main() -> None:
240
  hotdog_classify(pipeline_hot_dog, tab_hotdogs)
241
 
242
 
 
 
 
243
 
244
  if __name__ == "__main__":
245
  main()
 
17
  from input.input_handling import setup_input
18
  from maps.alps_map import present_alps_map
19
  from maps.obs_map import present_obs_map
20
+ from maps.obs_map import add_header_text as add_header_text_obs_map
21
+
22
  from utils.st_logs import setup_logging, parse_log_buffer
23
+ from utils.workflow_state import WorkflowFSM, WorkflowState
24
+ from classifier.classifier_image import cetacean_classify_and_review, cetacean_classify_list, cetacean_show_classifications
25
  from classifier.classifier_hotdog import hotdog_classify
26
 
27
 
 
60
  st.session_state.public_observation = {}
61
 
62
  if "classify_whale_done" not in st.session_state:
63
+ st.session_state.classify_whale_done = {}
64
 
65
  if "whale_prediction1" not in st.session_state:
66
+ st.session_state.whale_prediction1 = {}
67
 
68
  if "tab_log" not in st.session_state:
69
  st.session_state.tab_log = None
70
 
71
# create and init the state machine (once per session)
if "workflow_fsm" not in st.session_state:
    st.session_state.workflow_fsm = WorkflowFSM()

# add progress indicator to session_state: two sidebar placeholders, one for
# the status text and one for the progress bar.
# NOTE(review): the guard tests the key "progress" while the value is stored
# under "disp_progress"; nothing visible here sets "progress", so the
# placeholders appear to be re-created on every script run - confirm that this
# is intended.
if "progress" not in st.session_state:
    st.session_state.disp_progress = [st.sidebar.empty(), st.sidebar.empty()]
79
+
80
def refresh_progress():
    """Redraw the sidebar progress text and progress bar from the workflow FSM.

    Reads the FSM from `st.session_state.workflow_fsm` and writes into the two
    placeholders held in `st.session_state.disp_progress`.
    """
    with st.sidebar:
        fsm = st.session_state.workflow_fsm
        total = fsm.num_states
        position = fsm.state_number
        label = fsm.state_name

        text_slot = st.session_state.disp_progress[0]
        text_slot.markdown(f"*Progress: {position}/{total}. Current: {label}.*")

        bar_slot = st.session_state.disp_progress[1]
        bar_slot.progress(position/total)
87
+
88
 
89
  def main() -> None:
90
  """
 
118
  st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "*:gray[Dev:coordinates]*", "Log", "Beautiful cetaceans"])
119
  st.session_state.tab_log = tab_log
120
 
121
+ refresh_progress()
122
+ # add button to sidebar, with the callback to refesh_progress
123
+ st.sidebar.button("Refresh Progress", on_click=refresh_progress)
124
+
125
 
126
  # create a sidebar, and parse all the input (returned as `observations` object)
127
  observations = setup_input(viewcontainer=st.sidebar)
128
 
129
+
130
  if 0:## WIP
131
  # goal of this code is to allow the user to override the ML prediction, before transmitting an observations
132
  predicted_class = st.sidebar.selectbox("Predicted Class", viewer.WHALE_CLASSES)
 
142
  with tab_map:
143
  # visual structure: a couple of toggles at the top, then the map inlcuding a
144
  # dropdown for tileset selection.
145
+ add_header_text_obs_map()
146
  tab_map_ui_cols = st.columns(2)
147
  with tab_map_ui_cols[0]:
148
  show_db_points = st.toggle("Show Points from DB", True)
 
202
 
203
  # Display submitted observation
204
  if st.sidebar.button("Validate"):
205
+ # TODO 25.01 - it seems unclear what validation is actually happening *after* the button click.
206
+
207
  # create a dictionary with the submitted observation
208
  submitted_data = observations
209
  st.session_state.observations = observations
210
+ # advance two steps, since the code for enabling the validate button is in a different branch right now
211
+ st.session_state.workflow_fsm.advance() # init => data_entry_complete
212
+ st.session_state.workflow_fsm.advance() # data_entry_complete => data_entry_validated
213
 
214
  tab_log.info(f"{st.session_state.observations}")
215
 
 
231
  Once inference is complete, the top three predictions are shown.
232
  You can override the prediction by selecting a species from the dropdown.*""")
233
 
234
+
235
+ with tab_inference:
236
+ # test if the fsm is already at a point where results should be presented
237
+ cur_state_i = st.session_state.workflow_fsm.state_number
238
+ # here, if past manual inspection, we show the results
239
+ # elif past ml_completed, we show the results and the choice to manually validate
240
+ # else, we run the classifier (and show the results)
241
+ plan = "?"
242
+ if cur_state_i >= WorkflowState.MANUAL_REVIEW_COMPLETE.value:
243
+ plan = "show results"
244
+ elif cur_state_i >= WorkflowState.ML_COMPLETED.value:
245
+ plan = "present manual validation (with results shown)"
246
+ elif cur_state_i >= WorkflowState.DATA_VALIDATED.value:
247
+ plan = "run classifier"
248
+
249
+ st.info(f"Current state: {cur_state_i} [{WorkflowState.ML_COMPLETED.value}]. Plan: {plan}")
250
+
251
+ if plan == 'run classifier':
252
+
253
+ if tab_inference.button("Identify with cetacean classifier"):
254
+ #pipe = pipeline("image-classification", model="Saving-Willy/cetacean-classifier", trust_remote_code=True)
255
+ cetacean_classifier = AutoModelForImageClassification.from_pretrained("Saving-Willy/cetacean-classifier",
256
+ revision=classifier_revision,
257
+ trust_remote_code=True)
258
+ r = cetacean_classify_list(cetacean_classifier)
259
+ if r:
260
+ st.session_state.workflow_fsm.advance() # data_entry_validated => ml_classification_started
261
+ refresh_progress()
262
+ #cetacean_classify_and_review(cetacean_classifier)
263
+ # now, we can trigger the next state, which is the manual review of the classifications
264
+ st.write(f"megatextc {cur_state_i}")
265
+ r = cetacean_show_classifications()
266
+ if r:
267
+ st.session_state.workflow_fsm.advance() # ml_classification_completed => manual_inspection_completed
268
+ refresh_progress()
269
+
270
+ elif plan == 'present manual validation (with results shown)':
271
+ # show the results and the choice to manually validate
272
+ st.write(f"megatexta {cur_state_i}")
273
+ r = cetacean_show_classifications()
274
+ if r:
275
+ st.session_state.workflow_fsm.advance() # ml_classification_completed => manual_inspection_completed
276
+
277
+ elif plan == 'show results':
278
+ r = cetacean_show_classifications()
279
+ # just showing it, no advance.
280
 
281
 
282
+ if 0:
283
+ if cur_state_i >= WorkflowState.ML_COMPLETED.value:
284
+ # ML DONE, let's show it
285
+ with tab_inference:
286
+ st.write(f"megatexta {cur_state_i}")
287
+ r = cetacean_show_classifications()
288
+ if r:
289
+ st.session_state.workflow_fsm.advance() # ml_classification_completed => manual_inspection_completed
290
+ else:
291
+ with tab_inference:
292
+ st.write(f"megatextb {cur_state_i}")
293
+ # st.session_state.workflow_fsm.advance() # init => data_entry_complete
294
+ if st.session_state.images is None: # TODO: with FSM we check a state, not just images.
295
+ # TODO: cleaner design to disable the button until data input done?
296
+ st.info("Please upload an image first.")
297
+ else:
298
+ pass
299
+
300
+
301
 
 
302
 
303
 
304
  # inside the hotdog tab, on button press we call a 2nd model (totally unrelated at present, just for demo
 
323
  hotdog_classify(pipeline_hot_dog, tab_hotdogs)
324
 
325
 
326
+ # after all other processing, we can show the stage/state
327
+ refresh_progress()
328
+
329
 
330
  if __name__ == "__main__":
331
  main()
src/utils/workflow_state.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transitions import Machine
2
+ from enum import Enum
3
+
4
# ANSI escape sequences used to colour console output
OKBLUE = '\033[94m'
OKGREEN = '\033[92m'
OKCYAN = '\033[96m'
FAIL = '\033[91m'
ENDC = '\033[0m'  # ends a coloured span

# define the states
# 0. init
# 1. data entry complete
# 2. data entry validated
# 3. ML classification started (can be long running on batch)
# 4. ML classification completed
# 5. manual inspection / adjustment of classification completed
# 6. data uploaded
# NOTE(review): this list has seven entries, whereas the WorkflowState enum in
# this file defines six (ML_STARTED is commented out there).
states = [
    'init',
    'data_entry_complete',
    'data_entry_validated',
    'ml_classification_started',
    'ml_classification_completed',
    'manual_inspection_completed',
    'data_uploaded',
]


# define an enum for the states, automatically giving integers according to
# the position in the list
# - this is useful for the transitions
# maybe this needs to use setattr or similar
workflow_phases = Enum('StateEnum', [(name, idx) for idx, name in enumerate(states)])
26
+
27
+
28
class WorkflowState(Enum):
    """Ordered phases of the observation workflow.

    The integer values follow the order in which the phases are passed
    through, so callers can compare `.value` to test how far the workflow
    has progressed.
    """

    INIT = 0
    DATA_ENTRY_COMPLETE = 1
    DATA_VALIDATED = 2
    # ML_STARTED = 3  (disabled: there is currently no separate "started" phase)
    ML_COMPLETED = 3
    MANUAL_REVIEW_COMPLETE = 4
    UPLOADED = 5
36
+
37
+
38
+ # TODO: refactor the FSM to have explicit named states, and write a helper function to determine the next state and advance to it.
39
+ # this allows either triggering by name, or being a bit lazy and saying "advance" and it will go to the next state..
40
+ # maybe a cleaner way is to say completed('X') and then whatever the next state from X is can be taken. Instead of knowing
41
+ # what the next state is (because that was supposed to be defined here in the specification, and not in each phase)
42
+ #
43
+ # also add a "did we pass stage X" function, by name. This will make it easy to choose what to present, what actions to do next etc.
44
+
45
+
46
class WorkflowFSM:
    """Linear state machine that tracks progress through the workflow.

    The states are the member names of `WorkflowState`, in definition order.
    The `advance` trigger moves from each state to the one after it, and
    `reset` returns to `INIT` from any state. The instance is its own
    `transitions` model, so the machine maintains `self.state`.
    """

    def __init__(self):
        # Define states as strings (transitions requirement)
        self.states = [member.name for member in WorkflowState]
        # TODO: what is the point of the enum? I can just take the list and do an enumerate on it.??

        # Create state machine, using this object as the model
        self.machine = Machine(
            model=self,
            states=self.states,
            initial=WorkflowState.INIT.name,
        )

        # Chain each state to its successor under the single 'advance' trigger
        for src, dst in zip(self.states[:-1], self.states[1:]):
            self.machine.add_transition(trigger='advance', source=src, dest=dst)

        # Add reset transition, available from every state
        self.machine.add_transition(
            trigger='reset',
            source='*',
            dest=WorkflowState.INIT.name,
        )

        # Add callbacks for logging around every state change
        self.machine.before_state_change = self._log_transition
        self.machine.after_state_change = self._post_transition

    def _cprint(self, msg:str, color:str=OKCYAN):
        """Print `msg` wrapped in the given ANSI colour code."""
        print(f"{color}{msg}{ENDC}")

    def _advance_state(self):
        """Return the name of the state after the current one.

        The current state is returned when it is already the last one.
        """
        position = self.states.index(self.state)
        at_end = position >= len(self.states) - 1
        return self.state if at_end else self.states[position + 1]

    def _log_transition(self):
        # TODO: use logger, not printing.
        self._cprint(f"[FSM] -> Transitioning from {self.state}")

    def _post_transition(self):
        # TODO: use logger, not printing.
        self._cprint(f"[FSM] -| Transitioned to {self.state}")

    def advance(self):
        """Move to the next state.

        Raises:
            RuntimeError: if the machine is already in its final state.
        """
        if self.state_number >= len(self.states) - 1:
            # maybe too aggressive to exception here?
            raise RuntimeError("Already at final state")
        self.trigger('advance')

    @property
    def state_number(self) -> int:
        """Position of the current state within `self.states`."""
        return self.states.index(self.state)

    @property
    def state_name(self) -> str:
        """Name of the current state."""
        return self.state

    @property
    def num_states(self) -> int:
        """Total number of states in the machine."""
        return len(self.states)
121
+