rmm commited on
Commit
5a21040
·
1 Parent(s): 80b4be6

Revert "feat: first implementation of an FSM to keep track of phase"

Browse files

This reverts commit 80b4be61b16c215c79aba50f7e39fde2bfc81755.

- I learned what I needed to but I don't like the FSM implementation,
and I created plenty of mess in main that doesn't need to remain.

--> reverting.

requirements.txt CHANGED
@@ -10,10 +10,6 @@ streamlit_folium==0.23.1
10
 
11
  # backend
12
  datasets==3.0.2
13
- # - FSM
14
- transitions==0.9.2
15
- # optional, dev for the FSM (diagrams)
16
- # pyperclip==1.9.0
17
 
18
 
19
  # running ML models
 
10
 
11
  # backend
12
  datasets==3.0.2
 
 
 
 
13
 
14
 
15
  # running ML models
src/classifier/classifier_image.py CHANGED
@@ -11,111 +11,7 @@ from hf_push_observations import push_observations
11
  from utils.grid_maker import gridder
12
  from utils.metadata_handler import metadata2md
13
 
14
- # need to divide this into two functions, one for the classification and one for the display
15
- # it is currently somewhat interleaved, not totally clear how to separate them.
16
- # perhaps we have more stages than I realised.
17
- # ML started, ML completed, manual review completed, data uploaded
18
-
19
- # for now, let's implement the division between ML classification, and display+manual review.
20
-
21
- def cetacean_classify_list(cetacean_classifier):
22
- success = False
23
-
24
- files = st.session_state.files
25
- images = st.session_state.images
26
- observations = st.session_state.observations
27
-
28
- #batch_size, row_size, page = gridder(files)
29
- #grid = st.columns(row_size)
30
- #col = 0
31
-
32
- for file in files:
33
- key = file.name
34
- image = images[key]
35
-
36
- observation = observations[key].to_dict()
37
- # run classifier model on `image`, and persistently store the output
38
- out = cetacean_classifier(image) # get top 3 matches
39
- st.session_state.whale_prediction1[key] = out['predictions'][0]
40
- st.session_state.classify_whale_done[key] = True # TODO 25.01 unclear what this is for;
41
- msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done[key]}, whale_prediction1: {st.session_state.whale_prediction1[key]}"
42
- g_logger.info(msg)
43
-
44
- observations[key].set_top_predictions(out['predictions'][:])
45
-
46
- st.session_state.public_observation[key] = observation #
47
- msg = f"[D] full observation after inference: {observation}"
48
- g_logger.debug(msg)
49
- print(msg)
50
-
51
- # TODO: add some mech to test if it was successful.
52
- success = True
53
- st.balloons()
54
- return success
55
-
56
- def cetacean_show_classifications():
57
- st.write("TOP TEXT")
58
- st.write("Reviewing the classifications :construction:")
59
- files = st.session_state.files
60
- images = st.session_state.images
61
- observations = st.session_state.observations
62
-
63
- batch_size, row_size, page = gridder(files)
64
-
65
- grid = st.columns(row_size)
66
- col = 0
67
-
68
- for file in files:
69
- key = file.name
70
- image = images[key]
71
-
72
- with grid[col]:
73
- st.image(image, use_column_width=True)
74
- observation = observations[key].to_dict()
75
- # fetch the classification results
76
- # run classifier model on `image`, and persistently store the output
77
- msg = f"[D]2b classify_whale_done ({file}): {st.session_state.classify_whale_done[key]}, whale_prediction1: {st.session_state.whale_prediction1[key]}"
78
- g_logger.info(msg)
79
-
80
- # dropdown for selecting/overriding the species prediction
81
- # TODO: the "it's done" flag seems to get reset when we re-load the tab. Not quite right.
82
- if not st.session_state.classify_whale_done[key]:
83
- #selected_class = st.sidebar.selectbox("Species", viewer.WHALE_CLASSES,
84
- # TODO: ask LV why it is in the sidebar, and not in the grid
85
- selected_class = st.selectbox("Species", viewer.WHALE_CLASSES,
86
- index=None, placeholder="Species not yet identified...",
87
- disabled=True, key=f"cldd_{key}")
88
- else:
89
- pred1 = st.session_state.whale_prediction1[key]
90
- # get index of pred1 from WHALE_CLASSES, none if not present
91
- print(f"[D] pred1: {pred1}")
92
- ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
93
- selected_class = st.selectbox(f"Species for {file.name}", viewer.WHALE_CLASSES, index=ix)
94
-
95
- observation['predicted_class'] = selected_class
96
- if selected_class != st.session_state.whale_prediction1[key]:
97
- observation['class_overriden'] = selected_class
98
-
99
- st.session_state.public_observation = observation
100
- st.button(f"Upload observation for {file.name} to THE INTERNET!", on_click=push_observations)
101
- # TODO: the metadata only fills properly if `validate` was clicked.
102
- st.markdown(metadata2md())
103
-
104
- msg = f"[D] full observation after inference: {observation}"
105
- g_logger.debug(msg)
106
- print(msg)
107
- # TODO: add a link to more info on the model, next to the button.
108
- whale_classes = observations[key].top_predictions
109
- # render images for the top 3 (that is what the model api returns)
110
- n = len(whale_classes)
111
- st.markdown(f"Top {n} Predictions for {file.name}")
112
- for i in range(n):
113
- viewer.display_whale(whale_classes, i)
114
- col = (col + 1) % row_size
115
- return True
116
-
117
-
118
- def cetacean_classify_and_review(cetacean_classifier):
119
  files = st.session_state.files
120
  images = st.session_state.images
121
  observations = st.session_state.observations
 
11
  from utils.grid_maker import gridder
12
  from utils.metadata_handler import metadata2md
13
 
14
+ def cetacean_classify(cetacean_classifier):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  files = st.session_state.files
16
  images = st.session_state.images
17
  observations = st.session_state.observations
src/main.py CHANGED
@@ -17,11 +17,8 @@ import whale_viewer as viewer
17
  from input.input_handling import setup_input
18
  from maps.alps_map import present_alps_map
19
  from maps.obs_map import present_obs_map
20
- from maps.obs_map import add_header_text as add_header_text_obs_map
21
-
22
  from utils.st_logs import setup_logging, parse_log_buffer
23
- from utils.workflow_state import WorkflowFSM, WorkflowState
24
- from classifier.classifier_image import cetacean_classify_and_review, cetacean_classify_list, cetacean_show_classifications
25
  from classifier.classifier_hotdog import hotdog_classify
26
 
27
 
@@ -60,31 +57,14 @@ if "public_observation" not in st.session_state:
60
  st.session_state.public_observation = {}
61
 
62
  if "classify_whale_done" not in st.session_state:
63
- st.session_state.classify_whale_done = {}
64
 
65
  if "whale_prediction1" not in st.session_state:
66
- st.session_state.whale_prediction1 = {}
67
 
68
  if "tab_log" not in st.session_state:
69
  st.session_state.tab_log = None
70
 
71
- if "workflow_fsm" not in st.session_state:
72
- # create and init the state machine
73
- st.session_state.workflow_fsm = WorkflowFSM()
74
-
75
- # add progress indicator to session_state
76
- if "progress" not in st.session_state:
77
- with st.sidebar:
78
- st.session_state.disp_progress = [st.empty(), st.empty()]
79
-
80
- def refresh_progress():
81
- with st.sidebar:
82
- tot = st.session_state.workflow_fsm.num_states
83
- cur_i = st.session_state.workflow_fsm.state_number
84
- cur_t = st.session_state.workflow_fsm.state_name
85
- st.session_state.disp_progress[0].markdown(f"*Progress: {cur_i}/{tot}. Current: {cur_t}.*")
86
- st.session_state.disp_progress[1].progress(cur_i/tot)
87
-
88
 
89
  def main() -> None:
90
  """
@@ -118,15 +98,11 @@ def main() -> None:
118
  st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "*:gray[Dev:coordinates]*", "Log", "Beautiful cetaceans"])
119
  st.session_state.tab_log = tab_log
120
 
121
- refresh_progress()
122
- # add button to sidebar, with the callback to refesh_progress
123
- st.sidebar.button("Refresh Progress", on_click=refresh_progress)
124
-
125
 
126
  # create a sidebar, and parse all the input (returned as `observations` object)
127
  observations = setup_input(viewcontainer=st.sidebar)
128
 
129
-
130
  if 0:## WIP
131
  # goal of this code is to allow the user to override the ML prediction, before transmitting an observations
132
  predicted_class = st.sidebar.selectbox("Predicted Class", viewer.WHALE_CLASSES)
@@ -142,7 +118,7 @@ def main() -> None:
142
  with tab_map:
143
  # visual structure: a couple of toggles at the top, then the map inlcuding a
144
  # dropdown for tileset selection.
145
- add_header_text_obs_map()
146
  tab_map_ui_cols = st.columns(2)
147
  with tab_map_ui_cols[0]:
148
  show_db_points = st.toggle("Show Points from DB", True)
@@ -202,14 +178,9 @@ def main() -> None:
202
 
203
  # Display submitted observation
204
  if st.sidebar.button("Validate"):
205
- # TODO 25.01 - it seems unclear what validation is actually happening *after* the button click.
206
-
207
  # create a dictionary with the submitted observation
208
  submitted_data = observations
209
  st.session_state.observations = observations
210
- # advance two steps, since the code for enabling the validate button is in a different branch right now
211
- st.session_state.workflow_fsm.advance() # init => data_entry_complete
212
- st.session_state.workflow_fsm.advance() # data_entry_complete => data_entry_validated
213
 
214
  tab_log.info(f"{st.session_state.observations}")
215
 
@@ -231,74 +202,20 @@ def main() -> None:
231
  Once inference is complete, the top three predictions are shown.
232
  You can override the prediction by selecting a species from the dropdown.*""")
233
 
234
-
235
- with tab_inference:
236
- # test if the fsm is already at a point where results should be presented
237
- cur_state_i = st.session_state.workflow_fsm.state_number
238
- # here, if past manual inspection, we show the results
239
- # elif past ml_completed, we show the results and the choice to manually validate
240
- # else, we run the classifier (and show the results)
241
- plan = "?"
242
- if cur_state_i >= WorkflowState.MANUAL_REVIEW_COMPLETE.value:
243
- plan = "show results"
244
- elif cur_state_i >= WorkflowState.ML_COMPLETED.value:
245
- plan = "present manual validation (with results shown)"
246
- elif cur_state_i >= WorkflowState.DATA_VALIDATED.value:
247
- plan = "run classifier"
248
-
249
- st.info(f"Current state: {cur_state_i} [{WorkflowState.ML_COMPLETED.value}]. Plan: {plan}")
250
-
251
- if plan == 'run classifier':
252
-
253
- if tab_inference.button("Identify with cetacean classifier"):
254
- #pipe = pipeline("image-classification", model="Saving-Willy/cetacean-classifier", trust_remote_code=True)
255
- cetacean_classifier = AutoModelForImageClassification.from_pretrained("Saving-Willy/cetacean-classifier",
256
- revision=classifier_revision,
257
- trust_remote_code=True)
258
- r = cetacean_classify_list(cetacean_classifier)
259
- if r:
260
- st.session_state.workflow_fsm.advance() # data_entry_validated => ml_classification_started
261
- refresh_progress()
262
- #cetacean_classify_and_review(cetacean_classifier)
263
- # now, we can trigger the next state, which is the manual review of the classifications
264
- st.write(f"megatextc {cur_state_i}")
265
- r = cetacean_show_classifications()
266
- if r:
267
- st.session_state.workflow_fsm.advance() # ml_classification_completed => manual_inspection_completed
268
- refresh_progress()
269
-
270
- elif plan == 'present manual validation (with results shown)':
271
- # show the results and the choice to manually validate
272
- st.write(f"megatexta {cur_state_i}")
273
- r = cetacean_show_classifications()
274
- if r:
275
- st.session_state.workflow_fsm.advance() # ml_classification_completed => manual_inspection_completed
276
-
277
- elif plan == 'show results':
278
- r = cetacean_show_classifications()
279
- # just showing it, no advance.
280
 
281
 
282
- if 0:
283
- if cur_state_i >= WorkflowState.ML_COMPLETED.value:
284
- # ML DONE, let's show it
285
- with tab_inference:
286
- st.write(f"megatexta {cur_state_i}")
287
- r = cetacean_show_classifications()
288
- if r:
289
- st.session_state.workflow_fsm.advance() # ml_classification_completed => manual_inspection_completed
290
- else:
291
- with tab_inference:
292
- st.write(f"megatextb {cur_state_i}")
293
- # st.session_state.workflow_fsm.advance() # init => data_entry_complete
294
- if st.session_state.images is None: # TODO: with FSM we check a state, not just images.
295
- # TODO: cleaner design to disable the button until data input done?
296
- st.info("Please upload an image first.")
297
- else:
298
- pass
299
-
300
-
301
 
 
302
 
303
 
304
  # inside the hotdog tab, on button press we call a 2nd model (totally unrelated at present, just for demo
@@ -323,9 +240,6 @@ def main() -> None:
323
  hotdog_classify(pipeline_hot_dog, tab_hotdogs)
324
 
325
 
326
- # after all other processing, we can show the stage/state
327
- refresh_progress()
328
-
329
 
330
  if __name__ == "__main__":
331
  main()
 
17
  from input.input_handling import setup_input
18
  from maps.alps_map import present_alps_map
19
  from maps.obs_map import present_obs_map
 
 
20
  from utils.st_logs import setup_logging, parse_log_buffer
21
+ from classifier.classifier_image import cetacean_classify
 
22
  from classifier.classifier_hotdog import hotdog_classify
23
 
24
 
 
57
  st.session_state.public_observation = {}
58
 
59
  if "classify_whale_done" not in st.session_state:
60
+ st.session_state.classify_whale_done = False
61
 
62
  if "whale_prediction1" not in st.session_state:
63
+ st.session_state.whale_prediction1 = None
64
 
65
  if "tab_log" not in st.session_state:
66
  st.session_state.tab_log = None
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  def main() -> None:
70
  """
 
98
  st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "*:gray[Dev:coordinates]*", "Log", "Beautiful cetaceans"])
99
  st.session_state.tab_log = tab_log
100
 
 
 
 
 
101
 
102
  # create a sidebar, and parse all the input (returned as `observations` object)
103
  observations = setup_input(viewcontainer=st.sidebar)
104
 
105
+
106
  if 0:## WIP
107
  # goal of this code is to allow the user to override the ML prediction, before transmitting an observations
108
  predicted_class = st.sidebar.selectbox("Predicted Class", viewer.WHALE_CLASSES)
 
118
  with tab_map:
119
  # visual structure: a couple of toggles at the top, then the map inlcuding a
120
  # dropdown for tileset selection.
121
+ sw_map.add_header_text()
122
  tab_map_ui_cols = st.columns(2)
123
  with tab_map_ui_cols[0]:
124
  show_db_points = st.toggle("Show Points from DB", True)
 
178
 
179
  # Display submitted observation
180
  if st.sidebar.button("Validate"):
 
 
181
  # create a dictionary with the submitted observation
182
  submitted_data = observations
183
  st.session_state.observations = observations
 
 
 
184
 
185
  tab_log.info(f"{st.session_state.observations}")
186
 
 
202
  Once inference is complete, the top three predictions are shown.
203
  You can override the prediction by selecting a species from the dropdown.*""")
204
 
205
+ if tab_inference.button("Identify with cetacean classifier"):
206
+ #pipe = pipeline("image-classification", model="Saving-Willy/cetacean-classifier", trust_remote_code=True)
207
+ cetacean_classifier = AutoModelForImageClassification.from_pretrained("Saving-Willy/cetacean-classifier",
208
+ revision=classifier_revision,
209
+ trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
 
212
+ if st.session_state.images is None:
213
+ # TODO: cleaner design to disable the button until data input done?
214
+ st.info("Please upload an image first.")
215
+ else:
216
+ cetacean_classify(cetacean_classifier)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
+
219
 
220
 
221
  # inside the hotdog tab, on button press we call a 2nd model (totally unrelated at present, just for demo
 
240
  hotdog_classify(pipeline_hot_dog, tab_hotdogs)
241
 
242
 
 
 
 
243
 
244
  if __name__ == "__main__":
245
  main()
src/utils/workflow_state.py DELETED
@@ -1,121 +0,0 @@
1
- from transitions import Machine
2
- from enum import Enum
3
-
4
- OKBLUE = '\033[94m'
5
- OKGREEN = '\033[92m'
6
- OKCYAN = '\033[96m'
7
- FAIL = '\033[91m'
8
- ENDC = '\033[0m'
9
-
10
- # define the states
11
- # 0. init
12
- # 1. data entry complete
13
- # 2. data entry validated
14
- # 3. ML classification started (can be long running on batch)
15
- # 4. ML classification completed
16
- # 5. manual inspection / adjustment of classification completed
17
- # 6. data uploaded
18
-
19
- states = ['init', 'data_entry_complete', 'data_entry_validated', 'ml_classification_started', 'ml_classification_completed', 'manual_inspection_completed', 'data_uploaded']
20
-
21
-
22
- # define an enum for the states, automatically giving integers according to the position in the list
23
- # - this is useful for the transitions
24
- # maybe this needs to use setattr or similar
25
- workflow_phases = Enum('StateEnum', {state: i for i, state in enumerate(states)})
26
-
27
-
28
- class WorkflowState(Enum):
29
- INIT = 0
30
- DATA_ENTRY_COMPLETE = 1
31
- DATA_VALIDATED = 2
32
- #ML_STARTED = 3
33
- ML_COMPLETED = 3
34
- MANUAL_REVIEW_COMPLETE = 4
35
- UPLOADED = 5
36
-
37
-
38
- # TODO: refactor the FSM to have explicit named states, and write a helper function to determine the next state and advance to it.
39
- # this allows either triggering by name, or being a bit lazy and saying "advance" and it will go to the next state..
40
- # maybe a cleaner way is to say completed('X') and then whatever the next state from X is can be taken. Instead of knowing
41
- # what the next state is (becausee that was supposed to be defined her in the specification, and not in each phase)
42
- #
43
- # also add a "did we pass stage X" function, by name. This will make it easy to choose what to present, what actions to do next etc.
44
-
45
-
46
- class WorkflowFSM:
47
- def __init__(self):
48
- # Define states as strings (transitions requirement)
49
- self.states = [state.name for state in WorkflowState]
50
- # TODO: what is the point of the enum? I can just take the list and do an enumerate on it.??
51
-
52
-
53
- # Create state machine
54
- self.machine = Machine(
55
- model=self,
56
- states=self.states,
57
- initial=WorkflowState.INIT.name,
58
- )
59
-
60
- # Add transitions for each state to the next state
61
- for i in range(len(self.states) - 1):
62
- self.machine.add_transition(
63
- trigger='advance',
64
- source=self.states[i],
65
- dest=self.states[i + 1]
66
- )
67
-
68
- # Add reset transition
69
- self.machine.add_transition(
70
- trigger='reset',
71
- source='*',
72
- dest=WorkflowState.INIT.name
73
- )
74
-
75
- # Add callbacks for logging
76
- self.machine.before_state_change = self._log_transition
77
- self.machine.after_state_change = self._post_transition
78
-
79
- def _cprint(self, msg:str, color:str=OKCYAN):
80
- """Print colored message"""
81
- print(f"{color}{msg}{ENDC}")
82
-
83
-
84
- def _advance_state(self):
85
- """Determine the next state based on current state"""
86
- current_idx = self.states.index(self.state)
87
- if current_idx < len(self.states) - 1:
88
- return self.states[current_idx + 1]
89
- return self.state # Stay in final state if already there
90
-
91
- def _log_transition(self):
92
- # TODO: use logger, not printing.
93
- self._cprint(f"[FSM] -> Transitioning from {self.state}")
94
-
95
- def _post_transition(self):
96
- # TODO: use logger, not printing.
97
- self._cprint(f"[FSM] -| Transitioned to {self.state}")
98
-
99
-
100
- def advance(self):
101
- if self.state_number < len(self.states) - 1:
102
- self.trigger('advance')
103
- else:
104
- # maybe too aggressive to exception here?
105
- raise RuntimeError("Already at final state")
106
-
107
- @property
108
- def state_number(self) -> int:
109
- """Get the numerical value of current state"""
110
- return self.states.index(self.state)
111
-
112
- @property
113
- def state_name(self) -> str:
114
- """Get the name of current state"""
115
- return self.state
116
-
117
- # add a property for the number of states
118
- @property
119
- def num_states(self) -> int:
120
- return len(self.states)
121
-