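# NOTE(review): page-header residue captured together with this source
# ("Spaces:", "Sleeping", "Sleeping"); it is not part of the program.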
from transitions import Machine | |
from typing import List | |
# ANSI SGR escape sequences used to colour console output.
OKBLUE = "\033[94m"   # bright blue foreground
OKGREEN = "\033[92m"  # bright green foreground
OKCYAN = "\033[96m"   # bright cyan foreground
FAIL = "\033[91m"     # bright red foreground
ENDC = "\033[0m"      # reset all attributes
# Ordered workflow state names; presumably the sequence handed to
# WorkflowFSM (verify at call sites).
FSM_STATES = [
    "doing_data_entry",
    "data_entry_complete",
    "data_entry_validated",
    # "ml_classification_started",  # disabled: not part of the current workflow
    "ml_classification_completed",
    "manual_inspection_completed",
    "data_uploaded",
]
class WorkflowFSM:
    """Linear workflow state machine built on ``transitions.Machine``.

    States are visited strictly in the order given by ``state_sequence``:
    every state except the last gets a ``complete_<state>`` trigger that moves
    the model to the following state. Because the machine is created with
    ``model=self``, the current state name lives on ``self.state``.
    """

    def __init__(self, state_sequence: List[str]):
        """Build the machine and its completion transitions.

        Args:
            state_sequence: ordered, non-empty list of unique state names; the
                first entry becomes the initial state.

        Raises:
            ValueError: if ``state_sequence`` is empty or contains duplicates.
        """
        # Copy so later mutation of the caller's list cannot desynchronise
        # state_sequence / state_dict from the states the machine knows about.
        self.state_sequence = list(state_sequence)
        if not self.state_sequence:
            raise ValueError("state_sequence must contain at least one state")

        # Position of each state in the sequence; used for ordering checks.
        self.state_dict = {state: i for i, state in enumerate(self.state_sequence)}
        if len(self.state_dict) != len(self.state_sequence):
            # Duplicates would collapse in state_dict and make the index-based
            # helpers (is_in_state_or_beyond, current_state_index) wrong.
            raise ValueError("state_sequence contains duplicate state names")

        # Create state machine
        self.machine = Machine(
            model=self,
            states=self.state_sequence,
            initial=self.state_sequence[0],
        )

        # For each state (except the last), add a completion transition to the
        # next state.
        for current_state, next_state in zip(self.state_sequence,
                                             self.state_sequence[1:]):
            condition_name = f'is_in_{current_state}'
            self.machine.add_transition(
                trigger=f'complete_{current_state}',
                source=current_state,
                dest=next_state,
                conditions=[condition_name],
            )
            # The condition is referenced by name above and looked up on the
            # model when the trigger fires.
            setattr(self, condition_name, self._make_state_check(current_state))

        # Add callbacks for logging
        self.machine.before_state_change = self._log_transition
        self.machine.after_state_change = self._post_transition

    def _make_state_check(self, state_name: str):
        """Return a callable that is True while the model is in ``state_name``.

        ``transitions`` forwards trigger arguments to condition callables (with
        the default ``send_event=False``), so the returned function accepts and
        ignores them. The state name is captured by closure rather than as an
        overridable default argument.
        """
        def check(*_args, **_kwargs) -> bool:
            return self.is_in_state(state_name)
        return check

    def is_in_state(self, state_name: str) -> bool:
        """Check if we're currently in the specified state."""
        return self.state == state_name

    def complete_current_state(self) -> bool:
        """Signal that the current state is complete.

        Fires the ``complete_<current state>`` trigger if one exists.

        Returns:
            True if the state transition occurred, False otherwise. False is
            returned when the current state is the last one (no completion
            trigger exists), when the transition's condition blocked it, or
            when the trigger raised an exception (reported, not re-raised).
        """
        current_state = self.state
        trigger_func = getattr(self, f'complete_{current_state}', None)
        if trigger_func is None:
            return False
        try:
            # The trigger's own result tells whether the transition executed.
            return bool(trigger_func())
        except Exception as err:  # best-effort API: report, don't crash
            self._cprint(
                f"[FSM] !! Failed to complete state {current_state}: {err!r}",
                FAIL,
            )
            return False

    def is_in_state_or_beyond(self, state_name: str) -> bool:
        """Check if we have reached or passed the specified state.

        Compares the position of ``state_name`` in the sequence with the
        position of the current state.

        Raises:
            ValueError: if ``state_name`` is not one of the machine's states.
        """
        if state_name not in self.state_dict:
            raise ValueError(f"Invalid state: {state_name}")
        return self.state_dict[state_name] <= self.state_dict[self.state]

    def current_state(self) -> str:
        """Get the current state name."""
        return self.state

    def current_state_index(self) -> int:
        """Get the current state's position in ``state_sequence``."""
        return self.state_dict[self.state]

    def num_states(self) -> int:
        """Return the number of states in the workflow sequence."""
        return len(self.state_sequence)

    def _log_transition(self, *_args, **_kwargs):
        """before_state_change callback: report the state being left."""
        # TODO: use logger, not printing.
        # current_state must be *called*; interpolating the bare method would
        # print its repr instead of the state name.
        self._cprint(f"[FSM] -> Transitioning from {self.current_state()}")

    def _post_transition(self, *_args, **_kwargs):
        """after_state_change callback: report the state just entered."""
        # TODO: use logger, not printing.
        self._cprint(f"[FSM] -| Transitioned to {self.current_state()}")

    def _cprint(self, msg: str, color: str = OKCYAN):
        """Print ``msg`` wrapped in the ANSI ``color`` code and the reset code."""
        print(f"{color}{msg}{ENDC}")