# NOTE(review): the following appears to be page chrome captured when this
# file was copied from a web file viewer; it is not part of the program.
#   Spaces: Sleeping / Sleeping
#   File size: 3,597 Bytes
#   00bdefd 4854d2c 00bdefd 4854d2c 00bdefd
#   (line-number gutter 1-109 from the viewer)
from transitions import Machine
from typing import List
# ANSI SGR escape sequences for coloured terminal output.
OKBLUE = '\033[94m'
OKGREEN = '\033[92m'
OKCYAN = '\033[96m'
FAIL = '\033[91m'
ENDC = '\033[0m'  # resets attributes after a coloured message

# Workflow state names, listed in order.
# NOTE(review): presumably passed to WorkflowFSM as its state_sequence --
# confirm against the caller.
FSM_STATES = [
    'doing_data_entry',
    'data_entry_complete',
    'data_entry_validated',
    # 'ml_classification_started',  # disabled in the original; TODO confirm
    'ml_classification_completed',
    'manual_inspection_completed',
    'data_uploaded',
]
class WorkflowFSM:
    """Strictly linear workflow state machine built on ``transitions.Machine``.

    States are visited in the order given by ``state_sequence``.  For every
    state except the last, a trigger named ``complete_<state>`` is registered
    that moves the machine to the following state.  The instance is its own
    model, so ``self.state`` and the generated triggers live on this object.
    """

    def __init__(self, state_sequence: List[str]):
        """Build the machine and register one completion trigger per state.

        Args:
            state_sequence: State names in workflow order; the first entry
                becomes the initial state.

        Raises:
            IndexError: If ``state_sequence`` is empty.
        """
        # Snapshot the sequence so later mutation of the caller's list cannot
        # desynchronise ``state_dict`` from the states known to the machine.
        self.state_sequence = list(state_sequence)
        self.state_dict = {
            state: index for index, state in enumerate(self.state_sequence)
        }

        # Create state machine
        self.machine = Machine(
            model=self,
            states=self.state_sequence,
            initial=self.state_sequence[0],
        )

        # Pair each state with its successor; the last state gets no trigger.
        for current_state, next_state in zip(
            self.state_sequence, self.state_sequence[1:]
        ):
            # The condition is referenced by name below, so it has to exist
            # as an attribute on the model (this instance).
            setattr(
                self,
                f'is_in_{current_state}',
                self._make_state_check(current_state),
            )
            self.machine.add_transition(
                trigger=f'complete_{current_state}',
                source=current_state,
                dest=next_state,
                conditions=[f'is_in_{current_state}'],
            )

        # Add callbacks for logging
        self.machine.before_state_change = self._log_transition
        self.machine.after_state_change = self._post_transition

    def _make_state_check(self, state_name: str):
        """Return a condition callable that is true while in ``state_name``.

        The callable accepts and ignores any arguments: ``transitions``
        forwards the trigger's arguments to condition callbacks, and they
        must not be mistaken for the state name being checked.
        """
        def _check(*_args, **_kwargs) -> bool:
            return self.is_in_state(state_name)

        return _check

    def is_in_state(self, state_name: str) -> bool:
        """Check if we're currently in the specified state"""
        return self.state == state_name

    def complete_current_state(self) -> bool:
        """Signal that the current state is complete.

        Returns:
            True if a state transition occurred, False otherwise (no trigger
            exists for the current state, the trigger reported that it did
            not transition, or the trigger raised).
        """
        current_state = self.state
        trigger_func = getattr(self, f'complete_{current_state}', None)
        if trigger_func is None:
            # No ``complete_*`` trigger is registered for the last state.
            return False
        try:
            # The trigger's own return value says whether the transition
            # actually ran, so report that instead of assuming success.
            return bool(trigger_func())
        except Exception as exc:
            # Best effort: report instead of raising, but let
            # KeyboardInterrupt / SystemExit propagate.
            self._cprint(
                f"[FSM] !! Could not complete {current_state}: {exc}", FAIL
            )
            return False

    def is_in_state_or_beyond(self, state_name: str) -> bool:
        """Check if we have reached or passed the specified state.

        Compares the position of ``state_name`` in the sequence with the
        position of the current state.

        Raises:
            ValueError: If ``state_name`` is not one of the machine's states.
        """
        if state_name not in self.state_dict:
            raise ValueError(f"Invalid state: {state_name}")
        return self.state_dict[state_name] <= self.state_dict[self.state]

    @property
    def current_state(self) -> str:
        """Get the current state name"""
        return self.state

    @property
    def current_state_index(self) -> int:
        """Get the current state index"""
        return self.state_dict[self.state]

    @property
    def num_states(self) -> int:
        """Get the number of states in the sequence"""
        return len(self.state_sequence)

    def _log_transition(self, *_args, **_kwargs):
        """Announce the state being left; trigger arguments are ignored."""
        # TODO: use logger, not printing.
        self._cprint(f"[FSM] -> Transitioning from {self.current_state}")

    def _post_transition(self, *_args, **_kwargs):
        """Announce the state just entered; trigger arguments are ignored."""
        # TODO: use logger, not printing.
        self._cprint(f"[FSM] -| Transitioned to {self.current_state}")

    def _cprint(self, msg: str, color: str = OKCYAN):
        """Print colored message"""
        print(f"{color}{msg}{ENDC}")
# (stray table delimiter from the page capture; not part of the program)