Spaces:
Sleeping
feat: first implementation of an FSM to keep track of phase
Browse files- classifier_image gets split into multiple functions: run inference;
manual review/validation (with display of results); just display of
results
- workflow_state implements a FSM using `transitions`. It is a bit too
simple in the end, as the idea was to have a trigger that allows
moving to the next state without needing to know what that state is
called (otherwise the specification of pathways by data structs
doesn't simplify it). But it is too buggy this way, you can advance
in places that you shouldn't. So will refactor this.
- in main
- we convert some more session_state to dicts, to handle image batches
- add a simple widget to show progress through the workflow
- the main effort to stop losing progress is on the tab_inference.
lots of testing state to check what action to take. you can see
several attempts, to clean up now I understand a big bug was in
gating everything by the inference button.
- requirements.txt +4 -0
- src/classifier/classifier_image.py +105 -1
- src/main.py +102 -16
- src/utils/workflow_state.py +121 -0
@@ -10,6 +10,10 @@ streamlit_folium==0.23.1
|
|
10 |
|
11 |
# backend
|
12 |
datasets==3.0.2
|
|
|
|
|
|
|
|
|
13 |
|
14 |
|
15 |
# running ML models
|
|
|
10 |
|
11 |
# backend
|
12 |
datasets==3.0.2
|
13 |
+
# - FSM
|
14 |
+
transitions==0.9.2
|
15 |
+
# optional, dev for the FSM (diagrams)
|
16 |
+
# pyperclip==1.9.0
|
17 |
|
18 |
|
19 |
# running ML models
|
@@ -11,7 +11,111 @@ from hf_push_observations import push_observations
|
|
11 |
from utils.grid_maker import gridder
|
12 |
from utils.metadata_handler import metadata2md
|
13 |
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
files = st.session_state.files
|
16 |
images = st.session_state.images
|
17 |
observations = st.session_state.observations
|
|
|
11 |
from utils.grid_maker import gridder
|
12 |
from utils.metadata_handler import metadata2md
|
13 |
|
14 |
+
# need to divide this into two functions, one for the classification and one for the display
|
15 |
+
# it is currently somewhat interleaved, not totally clear how to separate them.
|
16 |
+
# perhaps we have more stages than I realised.
|
17 |
+
# ML started, ML completed, manual review completed, data uploaded
|
18 |
+
|
19 |
+
# for now, let's implement the division between ML classification, and display+manual review.
|
20 |
+
|
21 |
+
def cetacean_classify_list(cetacean_classifier):
|
22 |
+
success = False
|
23 |
+
|
24 |
+
files = st.session_state.files
|
25 |
+
images = st.session_state.images
|
26 |
+
observations = st.session_state.observations
|
27 |
+
|
28 |
+
#batch_size, row_size, page = gridder(files)
|
29 |
+
#grid = st.columns(row_size)
|
30 |
+
#col = 0
|
31 |
+
|
32 |
+
for file in files:
|
33 |
+
key = file.name
|
34 |
+
image = images[key]
|
35 |
+
|
36 |
+
observation = observations[key].to_dict()
|
37 |
+
# run classifier model on `image`, and persistently store the output
|
38 |
+
out = cetacean_classifier(image) # get top 3 matches
|
39 |
+
st.session_state.whale_prediction1[key] = out['predictions'][0]
|
40 |
+
st.session_state.classify_whale_done[key] = True # TODO 25.01 unclear what this is for;
|
41 |
+
msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done[key]}, whale_prediction1: {st.session_state.whale_prediction1[key]}"
|
42 |
+
g_logger.info(msg)
|
43 |
+
|
44 |
+
observations[key].set_top_predictions(out['predictions'][:])
|
45 |
+
|
46 |
+
st.session_state.public_observation[key] = observation #
|
47 |
+
msg = f"[D] full observation after inference: {observation}"
|
48 |
+
g_logger.debug(msg)
|
49 |
+
print(msg)
|
50 |
+
|
51 |
+
# TODO: add some mech to test if it was successful.
|
52 |
+
success = True
|
53 |
+
st.balloons()
|
54 |
+
return success
|
55 |
+
|
56 |
+
def cetacean_show_classifications():
|
57 |
+
st.write("TOP TEXT")
|
58 |
+
st.write("Reviewing the classifications :construction:")
|
59 |
+
files = st.session_state.files
|
60 |
+
images = st.session_state.images
|
61 |
+
observations = st.session_state.observations
|
62 |
+
|
63 |
+
batch_size, row_size, page = gridder(files)
|
64 |
+
|
65 |
+
grid = st.columns(row_size)
|
66 |
+
col = 0
|
67 |
+
|
68 |
+
for file in files:
|
69 |
+
key = file.name
|
70 |
+
image = images[key]
|
71 |
+
|
72 |
+
with grid[col]:
|
73 |
+
st.image(image, use_column_width=True)
|
74 |
+
observation = observations[key].to_dict()
|
75 |
+
# fetch the classification results
|
76 |
+
# run classifier model on `image`, and persistently store the output
|
77 |
+
msg = f"[D]2b classify_whale_done ({file}): {st.session_state.classify_whale_done[key]}, whale_prediction1: {st.session_state.whale_prediction1[key]}"
|
78 |
+
g_logger.info(msg)
|
79 |
+
|
80 |
+
# dropdown for selecting/overriding the species prediction
|
81 |
+
# TODO: the "it's done" flag seems to get reset when we re-load the tab. Not quite right.
|
82 |
+
if not st.session_state.classify_whale_done[key]:
|
83 |
+
#selected_class = st.sidebar.selectbox("Species", viewer.WHALE_CLASSES,
|
84 |
+
# TODO: ask LV why it is in the sidebar, and not in the grid
|
85 |
+
selected_class = st.selectbox("Species", viewer.WHALE_CLASSES,
|
86 |
+
index=None, placeholder="Species not yet identified...",
|
87 |
+
disabled=True, key=f"cldd_{key}")
|
88 |
+
else:
|
89 |
+
pred1 = st.session_state.whale_prediction1[key]
|
90 |
+
# get index of pred1 from WHALE_CLASSES, none if not present
|
91 |
+
print(f"[D] pred1: {pred1}")
|
92 |
+
ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
|
93 |
+
selected_class = st.selectbox(f"Species for {file.name}", viewer.WHALE_CLASSES, index=ix)
|
94 |
+
|
95 |
+
observation['predicted_class'] = selected_class
|
96 |
+
if selected_class != st.session_state.whale_prediction1[key]:
|
97 |
+
observation['class_overriden'] = selected_class
|
98 |
+
|
99 |
+
st.session_state.public_observation = observation
|
100 |
+
st.button(f"Upload observation for {file.name} to THE INTERNET!", on_click=push_observations)
|
101 |
+
# TODO: the metadata only fills properly if `validate` was clicked.
|
102 |
+
st.markdown(metadata2md())
|
103 |
+
|
104 |
+
msg = f"[D] full observation after inference: {observation}"
|
105 |
+
g_logger.debug(msg)
|
106 |
+
print(msg)
|
107 |
+
# TODO: add a link to more info on the model, next to the button.
|
108 |
+
whale_classes = observations[key].top_predictions
|
109 |
+
# render images for the top 3 (that is what the model api returns)
|
110 |
+
n = len(whale_classes)
|
111 |
+
st.markdown(f"Top {n} Predictions for {file.name}")
|
112 |
+
for i in range(n):
|
113 |
+
viewer.display_whale(whale_classes, i)
|
114 |
+
col = (col + 1) % row_size
|
115 |
+
return True
|
116 |
+
|
117 |
+
|
118 |
+
def cetacean_classify_and_review(cetacean_classifier):
|
119 |
files = st.session_state.files
|
120 |
images = st.session_state.images
|
121 |
observations = st.session_state.observations
|
@@ -17,8 +17,11 @@ import whale_viewer as viewer
|
|
17 |
from input.input_handling import setup_input
|
18 |
from maps.alps_map import present_alps_map
|
19 |
from maps.obs_map import present_obs_map
|
|
|
|
|
20 |
from utils.st_logs import setup_logging, parse_log_buffer
|
21 |
-
from
|
|
|
22 |
from classifier.classifier_hotdog import hotdog_classify
|
23 |
|
24 |
|
@@ -57,14 +60,31 @@ if "public_observation" not in st.session_state:
|
|
57 |
st.session_state.public_observation = {}
|
58 |
|
59 |
if "classify_whale_done" not in st.session_state:
|
60 |
-
st.session_state.classify_whale_done =
|
61 |
|
62 |
if "whale_prediction1" not in st.session_state:
|
63 |
-
st.session_state.whale_prediction1 =
|
64 |
|
65 |
if "tab_log" not in st.session_state:
|
66 |
st.session_state.tab_log = None
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
def main() -> None:
|
70 |
"""
|
@@ -98,11 +118,15 @@ def main() -> None:
|
|
98 |
st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "*:gray[Dev:coordinates]*", "Log", "Beautiful cetaceans"])
|
99 |
st.session_state.tab_log = tab_log
|
100 |
|
|
|
|
|
|
|
|
|
101 |
|
102 |
# create a sidebar, and parse all the input (returned as `observations` object)
|
103 |
observations = setup_input(viewcontainer=st.sidebar)
|
104 |
|
105 |
-
|
106 |
if 0:## WIP
|
107 |
# goal of this code is to allow the user to override the ML prediction, before transmitting an observations
|
108 |
predicted_class = st.sidebar.selectbox("Predicted Class", viewer.WHALE_CLASSES)
|
@@ -118,7 +142,7 @@ def main() -> None:
|
|
118 |
with tab_map:
|
119 |
# visual structure: a couple of toggles at the top, then the map inlcuding a
|
120 |
# dropdown for tileset selection.
|
121 |
-
|
122 |
tab_map_ui_cols = st.columns(2)
|
123 |
with tab_map_ui_cols[0]:
|
124 |
show_db_points = st.toggle("Show Points from DB", True)
|
@@ -178,9 +202,14 @@ def main() -> None:
|
|
178 |
|
179 |
# Display submitted observation
|
180 |
if st.sidebar.button("Validate"):
|
|
|
|
|
181 |
# create a dictionary with the submitted observation
|
182 |
submitted_data = observations
|
183 |
st.session_state.observations = observations
|
|
|
|
|
|
|
184 |
|
185 |
tab_log.info(f"{st.session_state.observations}")
|
186 |
|
@@ -202,20 +231,74 @@ def main() -> None:
|
|
202 |
Once inference is complete, the top three predictions are shown.
|
203 |
You can override the prediction by selecting a species from the dropdown.*""")
|
204 |
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
|
211 |
|
212 |
-
if
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
|
218 |
-
|
219 |
|
220 |
|
221 |
# inside the hotdog tab, on button press we call a 2nd model (totally unrelated at present, just for demo
|
@@ -240,6 +323,9 @@ def main() -> None:
|
|
240 |
hotdog_classify(pipeline_hot_dog, tab_hotdogs)
|
241 |
|
242 |
|
|
|
|
|
|
|
243 |
|
244 |
if __name__ == "__main__":
|
245 |
main()
|
|
|
17 |
from input.input_handling import setup_input
|
18 |
from maps.alps_map import present_alps_map
|
19 |
from maps.obs_map import present_obs_map
|
20 |
+
from maps.obs_map import add_header_text as add_header_text_obs_map
|
21 |
+
|
22 |
from utils.st_logs import setup_logging, parse_log_buffer
|
23 |
+
from utils.workflow_state import WorkflowFSM, WorkflowState
|
24 |
+
from classifier.classifier_image import cetacean_classify_and_review, cetacean_classify_list, cetacean_show_classifications
|
25 |
from classifier.classifier_hotdog import hotdog_classify
|
26 |
|
27 |
|
|
|
60 |
st.session_state.public_observation = {}
|
61 |
|
62 |
if "classify_whale_done" not in st.session_state:
|
63 |
+
st.session_state.classify_whale_done = {}
|
64 |
|
65 |
if "whale_prediction1" not in st.session_state:
|
66 |
+
st.session_state.whale_prediction1 = {}
|
67 |
|
68 |
if "tab_log" not in st.session_state:
|
69 |
st.session_state.tab_log = None
|
70 |
|
71 |
+
if "workflow_fsm" not in st.session_state:
|
72 |
+
# create and init the state machine
|
73 |
+
st.session_state.workflow_fsm = WorkflowFSM()
|
74 |
+
|
75 |
+
# add progress indicator to session_state
|
76 |
+
if "progress" not in st.session_state:
|
77 |
+
with st.sidebar:
|
78 |
+
st.session_state.disp_progress = [st.empty(), st.empty()]
|
79 |
+
|
80 |
+
def refresh_progress():
|
81 |
+
with st.sidebar:
|
82 |
+
tot = st.session_state.workflow_fsm.num_states
|
83 |
+
cur_i = st.session_state.workflow_fsm.state_number
|
84 |
+
cur_t = st.session_state.workflow_fsm.state_name
|
85 |
+
st.session_state.disp_progress[0].markdown(f"*Progress: {cur_i}/{tot}. Current: {cur_t}.*")
|
86 |
+
st.session_state.disp_progress[1].progress(cur_i/tot)
|
87 |
+
|
88 |
|
89 |
def main() -> None:
|
90 |
"""
|
|
|
118 |
st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "*:gray[Dev:coordinates]*", "Log", "Beautiful cetaceans"])
|
119 |
st.session_state.tab_log = tab_log
|
120 |
|
121 |
+
refresh_progress()
|
122 |
+
# add button to sidebar, with the callback to refesh_progress
|
123 |
+
st.sidebar.button("Refresh Progress", on_click=refresh_progress)
|
124 |
+
|
125 |
|
126 |
# create a sidebar, and parse all the input (returned as `observations` object)
|
127 |
observations = setup_input(viewcontainer=st.sidebar)
|
128 |
|
129 |
+
|
130 |
if 0:## WIP
|
131 |
# goal of this code is to allow the user to override the ML prediction, before transmitting an observations
|
132 |
predicted_class = st.sidebar.selectbox("Predicted Class", viewer.WHALE_CLASSES)
|
|
|
142 |
with tab_map:
|
143 |
# visual structure: a couple of toggles at the top, then the map inlcuding a
|
144 |
# dropdown for tileset selection.
|
145 |
+
add_header_text_obs_map()
|
146 |
tab_map_ui_cols = st.columns(2)
|
147 |
with tab_map_ui_cols[0]:
|
148 |
show_db_points = st.toggle("Show Points from DB", True)
|
|
|
202 |
|
203 |
# Display submitted observation
|
204 |
if st.sidebar.button("Validate"):
|
205 |
+
# TODO 25.01 - it seems unclear what validation is actually happening *after* the button click.
|
206 |
+
|
207 |
# create a dictionary with the submitted observation
|
208 |
submitted_data = observations
|
209 |
st.session_state.observations = observations
|
210 |
+
# advance two steps, since the code for enabling the validate button is in a different branch right now
|
211 |
+
st.session_state.workflow_fsm.advance() # init => data_entry_complete
|
212 |
+
st.session_state.workflow_fsm.advance() # data_entry_complete => data_entry_validated
|
213 |
|
214 |
tab_log.info(f"{st.session_state.observations}")
|
215 |
|
|
|
231 |
Once inference is complete, the top three predictions are shown.
|
232 |
You can override the prediction by selecting a species from the dropdown.*""")
|
233 |
|
234 |
+
|
235 |
+
with tab_inference:
|
236 |
+
# test if the fsm is already at a point where results should be presented
|
237 |
+
cur_state_i = st.session_state.workflow_fsm.state_number
|
238 |
+
# here, if past manual inspection, we show the results
|
239 |
+
# elif past ml_completed, we show the results and the choice to manually validate
|
240 |
+
# else, we run the classifier (and show the results)
|
241 |
+
plan = "?"
|
242 |
+
if cur_state_i >= WorkflowState.MANUAL_REVIEW_COMPLETE.value:
|
243 |
+
plan = "show results"
|
244 |
+
elif cur_state_i >= WorkflowState.ML_COMPLETED.value:
|
245 |
+
plan = "present manual validation (with results shown)"
|
246 |
+
elif cur_state_i >= WorkflowState.DATA_VALIDATED.value:
|
247 |
+
plan = "run classifier"
|
248 |
+
|
249 |
+
st.info(f"Current state: {cur_state_i} [{WorkflowState.ML_COMPLETED.value}]. Plan: {plan}")
|
250 |
+
|
251 |
+
if plan == 'run classifier':
|
252 |
+
|
253 |
+
if tab_inference.button("Identify with cetacean classifier"):
|
254 |
+
#pipe = pipeline("image-classification", model="Saving-Willy/cetacean-classifier", trust_remote_code=True)
|
255 |
+
cetacean_classifier = AutoModelForImageClassification.from_pretrained("Saving-Willy/cetacean-classifier",
|
256 |
+
revision=classifier_revision,
|
257 |
+
trust_remote_code=True)
|
258 |
+
r = cetacean_classify_list(cetacean_classifier)
|
259 |
+
if r:
|
260 |
+
st.session_state.workflow_fsm.advance() # data_entry_validated => ml_classification_started
|
261 |
+
refresh_progress()
|
262 |
+
#cetacean_classify_and_review(cetacean_classifier)
|
263 |
+
# now, we can trigger the next state, which is the manual review of the classifications
|
264 |
+
st.write(f"megatextc {cur_state_i}")
|
265 |
+
r = cetacean_show_classifications()
|
266 |
+
if r:
|
267 |
+
st.session_state.workflow_fsm.advance() # ml_classification_completed => manual_inspection_completed
|
268 |
+
refresh_progress()
|
269 |
+
|
270 |
+
elif plan == 'present manual validation (with results shown)':
|
271 |
+
# show the results and the choice to manually validate
|
272 |
+
st.write(f"megatexta {cur_state_i}")
|
273 |
+
r = cetacean_show_classifications()
|
274 |
+
if r:
|
275 |
+
st.session_state.workflow_fsm.advance() # ml_classification_completed => manual_inspection_completed
|
276 |
+
|
277 |
+
elif plan == 'show results':
|
278 |
+
r = cetacean_show_classifications()
|
279 |
+
# just showing it, no advance.
|
280 |
|
281 |
|
282 |
+
if 0:
|
283 |
+
if cur_state_i >= WorkflowState.ML_COMPLETED.value:
|
284 |
+
# ML DONE, let's show it
|
285 |
+
with tab_inference:
|
286 |
+
st.write(f"megatexta {cur_state_i}")
|
287 |
+
r = cetacean_show_classifications()
|
288 |
+
if r:
|
289 |
+
st.session_state.workflow_fsm.advance() # ml_classification_completed => manual_inspection_completed
|
290 |
+
else:
|
291 |
+
with tab_inference:
|
292 |
+
st.write(f"megatextb {cur_state_i}")
|
293 |
+
# st.session_state.workflow_fsm.advance() # init => data_entry_complete
|
294 |
+
if st.session_state.images is None: # TODO: with FSM we check a state, not just images.
|
295 |
+
# TODO: cleaner design to disable the button until data input done?
|
296 |
+
st.info("Please upload an image first.")
|
297 |
+
else:
|
298 |
+
pass
|
299 |
+
|
300 |
+
|
301 |
|
|
|
302 |
|
303 |
|
304 |
# inside the hotdog tab, on button press we call a 2nd model (totally unrelated at present, just for demo
|
|
|
323 |
hotdog_classify(pipeline_hot_dog, tab_hotdogs)
|
324 |
|
325 |
|
326 |
+
# after all other processing, we can show the stage/state
|
327 |
+
refresh_progress()
|
328 |
+
|
329 |
|
330 |
if __name__ == "__main__":
|
331 |
main()
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transitions import Machine
|
2 |
+
from enum import Enum
|
3 |
+
|
4 |
+
OKBLUE = '\033[94m'
|
5 |
+
OKGREEN = '\033[92m'
|
6 |
+
OKCYAN = '\033[96m'
|
7 |
+
FAIL = '\033[91m'
|
8 |
+
ENDC = '\033[0m'
|
9 |
+
|
10 |
+
# define the states
|
11 |
+
# 0. init
|
12 |
+
# 1. data entry complete
|
13 |
+
# 2. data entry validated
|
14 |
+
# 3. ML classification started (can be long running on batch)
|
15 |
+
# 4. ML classification completed
|
16 |
+
# 5. manual inspection / adjustment of classification completed
|
17 |
+
# 6. data uploaded
|
18 |
+
|
19 |
+
states = ['init', 'data_entry_complete', 'data_entry_validated', 'ml_classification_started', 'ml_classification_completed', 'manual_inspection_completed', 'data_uploaded']
|
20 |
+
|
21 |
+
|
22 |
+
# define an enum for the states, automatically giving integers according to the position in the list
|
23 |
+
# - this is useful for the transitions
|
24 |
+
# maybe this needs to use setattr or similar
|
25 |
+
workflow_phases = Enum('StateEnum', {state: i for i, state in enumerate(states)})
|
26 |
+
|
27 |
+
|
28 |
+
class WorkflowState(Enum):
|
29 |
+
INIT = 0
|
30 |
+
DATA_ENTRY_COMPLETE = 1
|
31 |
+
DATA_VALIDATED = 2
|
32 |
+
#ML_STARTED = 3
|
33 |
+
ML_COMPLETED = 3
|
34 |
+
MANUAL_REVIEW_COMPLETE = 4
|
35 |
+
UPLOADED = 5
|
36 |
+
|
37 |
+
|
38 |
+
# TODO: refactor the FSM to have explicit named states, and write a helper function to determine the next state and advance to it.
|
39 |
+
# this allows either triggering by name, or being a bit lazy and saying "advance" and it will go to the next state..
|
40 |
+
# maybe a cleaner way is to say completed('X') and then whatever the next state from X is can be taken. Instead of knowing
|
41 |
+
# what the next state is (becausee that was supposed to be defined her in the specification, and not in each phase)
|
42 |
+
#
|
43 |
+
# also add a "did we pass stage X" function, by name. This will make it easy to choose what to present, what actions to do next etc.
|
44 |
+
|
45 |
+
|
46 |
+
class WorkflowFSM:
|
47 |
+
def __init__(self):
|
48 |
+
# Define states as strings (transitions requirement)
|
49 |
+
self.states = [state.name for state in WorkflowState]
|
50 |
+
# TODO: what is the point of the enum? I can just take the list and do an enumerate on it.??
|
51 |
+
|
52 |
+
|
53 |
+
# Create state machine
|
54 |
+
self.machine = Machine(
|
55 |
+
model=self,
|
56 |
+
states=self.states,
|
57 |
+
initial=WorkflowState.INIT.name,
|
58 |
+
)
|
59 |
+
|
60 |
+
# Add transitions for each state to the next state
|
61 |
+
for i in range(len(self.states) - 1):
|
62 |
+
self.machine.add_transition(
|
63 |
+
trigger='advance',
|
64 |
+
source=self.states[i],
|
65 |
+
dest=self.states[i + 1]
|
66 |
+
)
|
67 |
+
|
68 |
+
# Add reset transition
|
69 |
+
self.machine.add_transition(
|
70 |
+
trigger='reset',
|
71 |
+
source='*',
|
72 |
+
dest=WorkflowState.INIT.name
|
73 |
+
)
|
74 |
+
|
75 |
+
# Add callbacks for logging
|
76 |
+
self.machine.before_state_change = self._log_transition
|
77 |
+
self.machine.after_state_change = self._post_transition
|
78 |
+
|
79 |
+
def _cprint(self, msg:str, color:str=OKCYAN):
|
80 |
+
"""Print colored message"""
|
81 |
+
print(f"{color}{msg}{ENDC}")
|
82 |
+
|
83 |
+
|
84 |
+
def _advance_state(self):
|
85 |
+
"""Determine the next state based on current state"""
|
86 |
+
current_idx = self.states.index(self.state)
|
87 |
+
if current_idx < len(self.states) - 1:
|
88 |
+
return self.states[current_idx + 1]
|
89 |
+
return self.state # Stay in final state if already there
|
90 |
+
|
91 |
+
def _log_transition(self):
|
92 |
+
# TODO: use logger, not printing.
|
93 |
+
self._cprint(f"[FSM] -> Transitioning from {self.state}")
|
94 |
+
|
95 |
+
def _post_transition(self):
|
96 |
+
# TODO: use logger, not printing.
|
97 |
+
self._cprint(f"[FSM] -| Transitioned to {self.state}")
|
98 |
+
|
99 |
+
|
100 |
+
def advance(self):
|
101 |
+
if self.state_number < len(self.states) - 1:
|
102 |
+
self.trigger('advance')
|
103 |
+
else:
|
104 |
+
# maybe too aggressive to exception here?
|
105 |
+
raise RuntimeError("Already at final state")
|
106 |
+
|
107 |
+
@property
|
108 |
+
def state_number(self) -> int:
|
109 |
+
"""Get the numerical value of current state"""
|
110 |
+
return self.states.index(self.state)
|
111 |
+
|
112 |
+
@property
|
113 |
+
def state_name(self) -> str:
|
114 |
+
"""Get the name of current state"""
|
115 |
+
return self.state
|
116 |
+
|
117 |
+
# add a property for the number of states
|
118 |
+
@property
|
119 |
+
def num_states(self) -> int:
|
120 |
+
return len(self.states)
|
121 |
+
|