File size: 3,597 Bytes
00bdefd
 
 
 
 
 
 
 
 
 
4854d2c
 
 
 
00bdefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4854d2c
 
 
 
 
 
 
 
 
 
 
 
 
00bdefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from transitions import Machine
from typing import List

# ANSI terminal escape sequences used to colour console output.
OKBLUE = "\033[94m"
OKGREEN = "\033[92m"
OKCYAN = "\033[96m"
FAIL = "\033[91m"
ENDC = "\033[0m"  # resets the terminal back to its default colour


# Ordered workflow stage names (presumably the sequence handed to WorkflowFSM
# by callers — confirm).
FSM_STATES = [
    "doing_data_entry",
    "data_entry_complete",
    "data_entry_validated",
    # "ml_classification_started",  # NOTE(review): stage disabled — confirm whether it should return
    "ml_classification_completed",
    "manual_inspection_completed",
    "data_uploaded",
]


class WorkflowFSM:
    """Linear, forward-only workflow state machine backed by ``transitions.Machine``.

    The machine starts in ``state_sequence[0]``. For every state except the
    last, a trigger method ``complete_<state>`` is registered on this object;
    calling it moves the machine to the next state in the sequence.
    """

    def __init__(self, state_sequence: List[str]):
        """Build the state machine for an ordered list of state names.

        Args:
            state_sequence: Ordered, non-empty list of unique state names. The
                first entry is the initial state; each ``complete_<state>``
                trigger advances to the entry that follows it.

        Raises:
            ValueError: If ``state_sequence`` is empty, or contains duplicate
                names (duplicates would make ``state_dict`` ambiguous and
                register competing ``complete_<state>`` transitions).
        """
        if not state_sequence:
            raise ValueError("state_sequence must contain at least one state")
        if len(set(state_sequence)) != len(state_sequence):
            raise ValueError(
                f"state_sequence contains duplicate states: {state_sequence}"
            )

        self.state_sequence = state_sequence
        # Position of each state in the sequence, used for ordering checks.
        self.state_dict = {state: i for i, state in enumerate(state_sequence)}

        # Create state machine; it attaches `state` and the trigger methods to self.
        self.machine = Machine(
            model=self,
            states=state_sequence,
            initial=state_sequence[0],
        )

        # Chain every state (except the last) to its successor.
        for current_state, next_state in zip(state_sequence, state_sequence[1:]):
            condition_name = f'is_in_{current_state}'
            # Condition method referenced by name in the transition below.
            setattr(self, condition_name, self._make_state_check(current_state))

            self.machine.add_transition(
                trigger=f'complete_{current_state}',
                source=current_state,
                dest=next_state,
                conditions=[condition_name],
            )

        # Add callbacks for logging
        self.machine.before_state_change = self._log_transition
        self.machine.after_state_change = self._post_transition

    def _make_state_check(self, state_name: str):
        """Return a condition callable that checks whether we are in ``state_name``.

        The returned callable accepts and ignores any positional/keyword
        arguments, so arguments passed to a trigger can never be mistaken for
        the state name being checked.
        """
        def _check(*_args, **_kwargs) -> bool:
            return self.is_in_state(state_name)
        return _check

    def is_in_state(self, state_name: str) -> bool:
        """Return True if the machine is currently in ``state_name``."""
        return self.state == state_name

    def complete_current_state(self) -> bool:
        """Signal that the current state is complete and advance to the next one.

        Returns:
            True if the transition was executed. False if the machine is
            already in the final state (no ``complete_<state>`` trigger
            exists), the transition was rejected, or an error was raised
            while performing it.
        """
        current_state = self.state
        trigger_func = getattr(self, f'complete_{current_state}', None)
        if trigger_func is None:
            # Only the final state lacks a `complete_<state>` trigger.
            return False

        try:
            # transitions triggers return False when a condition blocks the move.
            return bool(trigger_func())
        except Exception as err:  # best-effort API: report, don't propagate
            self._cprint(
                f"[FSM] !! Could not complete {current_state}: {err}", color=FAIL
            )
            return False

    def is_in_state_or_beyond(self, state_name: str) -> bool:
        """Return True if ``state_name`` has been reached or passed.

        Compares the position of ``state_name`` in the sequence with the
        position of the current state.

        Raises:
            ValueError: If ``state_name`` is not part of the sequence.
        """
        if state_name not in self.state_dict:
            raise ValueError(f"Invalid state: {state_name}")

        return self.state_dict[state_name] <= self.state_dict[self.state]

    @property
    def current_state(self) -> str:
        """Name of the current state."""
        return self.state

    @property
    def current_state_index(self) -> int:
        """Zero-based position of the current state in ``state_sequence``."""
        return self.state_dict[self.state]

    @property
    def num_states(self) -> int:
        """Number of states in the sequence."""
        return len(self.state_sequence)

    def _log_transition(self, *_args, **_kwargs):
        """Callback run before a state change; trigger arguments are ignored."""
        # TODO: use logger, not printing.
        self._cprint(f"[FSM] -> Transitioning from {self.current_state}")

    def _post_transition(self, *_args, **_kwargs):
        """Callback run after a state change; trigger arguments are ignored."""
        # TODO: use logger, not printing.
        self._cprint(f"[FSM] -| Transitioned to {self.current_state}")

    def _cprint(self, msg: str, color: str = OKCYAN):
        """Print ``msg`` wrapped in the given ANSI colour code, then reset."""
        print(f"{color}{msg}{ENDC}")