Spaces:
Running
Running
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| import random | |
| from abc import abstractmethod | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING, cast | |
| import requests | |
| from dotenv import load_dotenv | |
| from griptape.artifacts import ListArtifact, TextArtifact, InfoArtifact, BaseArtifact | |
| from griptape.configs import Defaults | |
| from griptape.configs.drivers import ( | |
| OpenAiDriversConfig, | |
| ) | |
| from griptape.drivers import ( | |
| GriptapeCloudVectorStoreDriver, | |
| LocalStructureRunDriver, | |
| OpenAiChatPromptDriver, | |
| ) | |
| from griptape.engines.rag import RagEngine | |
| from griptape.engines.rag.modules import ( | |
| TextChunksResponseRagModule, | |
| VectorStoreRetrievalRagModule, | |
| ) | |
| from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage | |
| from griptape.events import ( | |
| BaseEvent, | |
| EventBus, | |
| EventListener, | |
| FinishStructureRunEvent, | |
| ) | |
| from griptape.memory.structure import ConversationMemory | |
| from griptape.rules import Rule, Ruleset | |
| from griptape.structures import Agent, Workflow | |
| from griptape.tasks import CodeExecutionTask, StructureRunTask, ToolTask | |
| from griptape.tools import RagTool | |
| from statemachine import State, StateMachine | |
| from statemachine.factory import StateMachineMetaclass | |
| from parsers import UWConfigParser | |
| from uw_programmatic.single_question_machine import SingleQuestion | |
# Module-level logger for this file; the "griptape" logger is limited to
# ERROR and above.
logger = logging.getLogger(__name__)
logging.getLogger("griptape").setLevel(logging.ERROR)

# These names are only needed for type annotations, so they are imported
# for type checkers only.
if TYPE_CHECKING:
    from griptape.structures import Structure
    from griptape.tools import BaseTool
    from statemachine.event import Event

# Load variables from a .env file into the environment before any driver is
# configured (retrieve_vector_stores later reads GT_CLOUD_API_KEY via os.getenv).
load_dotenv()

# Sets max tokens and OpenAI as the driver.
Defaults.drivers_config = OpenAiDriversConfig(
    prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", max_tokens=4096)
)
def custom_dict_merge(dict1: dict, dict2: dict) -> dict:
    """Return a shallow merge of ``dict2`` over ``dict1``.

    Values from ``dict2`` replace those in ``dict1``, except when both sides
    hold a list under the same key: then the two lists are concatenated
    (``dict1``'s items first). Neither input dictionary is modified.
    """
    merged = dict1.copy()
    for name, incoming in dict2.items():
        extend_existing = (
            name in merged
            and isinstance(merged[name], list)
            and isinstance(incoming, list)
        )
        merged[name] = merged[name] + incoming if extend_existing else incoming
    return merged
| class UWBaseMachine(StateMachine): | |
| """Base class for a machine. | |
| Attributes: | |
| config_file (Path): The path to the configuration file. | |
| config (dict): The configuration data. | |
| outputs_to_user (list[str]): Outputs to return to the user. | |
| """ | |
def __init__(self, config_file: Path, **kwargs) -> None:
    """Parse the configuration, reset bookkeeping and register the event listener.

    Args:
        config_file: Path handed to ``UWConfigParser``; the parsed result is
            stored in ``self.config``.
        **kwargs: Accepted but not used by this method.
    """
    parser = UWConfigParser(config_file)
    self.config_parser = parser
    self.config = parser.parse()
    self._structures = {}
    # Store here in case needs multiple uses
    self.vector_stores = {}
    self.question_list: list = []

    # Parameters that have to come from the user
    self.page_range: tuple = ()
    self.question_number: int = 0
    self.taxonomy: list = []

    # Give-up tracking
    self.give_up_count = 0
    self.current_question_count = 0

    # Knowledge-base ids, to keep vector stores on track
    self.kb_ids = {}
    self.rejected_questions: list = []
    self.errored: bool = False

    # Every key of state_transitions starts out mapped to False.
    self.state_status: dict[str, bool] = dict.fromkeys(
        self.state_transitions, False
    )

    def on_event(event: BaseEvent) -> None:
        """Forward a griptape event from the EventBus as ``process_event``."""
        print(f"Received Griptape event: {json.dumps(event.to_dict(), indent=2)}")
        try:
            self.send(
                "process_event",
                event_={"type": "griptape_event", "value": event.to_dict()},
            )
        except Exception as e:
            errormsg = f"Would not allow process_event to be sent. Check to see if it is defined in the config.yaml. Error:{e}"
            raise ValueError(errormsg) from e

    # Drop all previously registered listeners, then listen only for
    # finished structure runs.
    EventBus.clear_event_listeners()
    listener = EventListener(on_event, event_types=[FinishStructureRunEvent])
    EventBus.add_event_listener(listener)

    # The attributes above are assigned before the base initializer runs.
    super().__init__()
def available_events(self) -> list[str]:
    """Return the unique events of the current state's transitions."""
    outgoing = self.current_state.transitions
    return outgoing.unique_events
def tools(self) -> dict[str, BaseTool]:
    """Return the tools for the machine.

    This base implementation has no body and yields ``None``; presumably
    concrete machines override it (TODO confirm).
    """
| def _current_state_config(self) -> dict: | |
| return self.config["states"][self.current_state_value] | |
@classmethod
def from_definition(  # noqa: C901, PLR0912
    cls, definition: dict, **extra_kwargs
) -> UWBaseMachine:
    """Build a machine class from a states/events definition and instantiate it.

    Args:
        definition: Mapping with a ``"states"`` entry (state id -> keyword
            arguments for ``State``) and an ``"events"`` entry (event name ->
            list of transition dicts). Every transition dict needs ``"from"``
            and ``"to"`` and may carry ``"relevance"``, ``"cond"``,
            ``"unless"``, ``"on"`` and ``"internal"``.
        **extra_kwargs: Added to the attributes of the generated class and
            also passed to its constructor.

    Returns:
        An instance of the generated subclass of ``cls``.

    Raises:
        ValueError: If a state cannot be built, or a transition lacks its
            source/destination or names an unknown state.
    """
    try:
        # These are the relevant states that need GOAP.
        states_instances = {
            state_id: State(**state_kwargs, value=state_id)
            for state_id, state_kwargs in definition["states"].items()
        }
    except Exception as e:
        errormsg = f"Error in state definition: {e}.\n"
        raise ValueError(errormsg) from e

    events = {}
    # source state id -> {event name: relevance text}
    state_transitions = {}

    def add_transition(name: str, transition) -> None:
        """Attach ``transition`` to event ``name``, merging with earlier ones."""
        if name in events:
            events[name] |= transition
        else:
            events[name] = transition

    for event_name, transitions in definition["events"].items():
        for transition_data in transitions:
            try:
                source_name = transition_data["from"]
                source = states_instances[source_name]
                target = states_instances[transition_data["to"]]
            except Exception as e:
                errormsg = f"Error:{e}. Please check your transitions to be sure each transition has a source and destination."
                raise ValueError(errormsg) from e

            relevance = transition_data.get("relevance", "")
            state_transitions.setdefault(source_name, {})[event_name] = relevance

            add_transition(
                event_name,
                source.to(
                    target,
                    event=event_name,
                    cond=transition_data.get("cond"),
                    unless=transition_data.get("unless"),
                    on=transition_data.get("on"),
                    internal=transition_data.get("internal"),
                ),
            )

    # Every state except "start" and "end" gets an internal self-transition on
    # "process_event" that runs the handler named on_event_<state id>.
    for state_id, state in states_instances.items():
        if state_id in ("end", "start"):
            continue
        add_transition(
            "process_event",
            state.to(
                state,
                event="process_event",
                on=f"on_event_{state_id}",
                internal=True,
            ),
        )

    attrs_mapper = {
        **extra_kwargs,
        **states_instances,
        **events,
        "state_transitions": state_transitions,
    }
    machine_class = StateMachineMetaclass(cls.__name__, (cls,), attrs_mapper)
    return cast(UWBaseMachine, machine_class(**extra_kwargs))
@classmethod
def from_config_file(
    cls,
    config_file: Path,
    **extra_kwargs,
) -> UWBaseMachine:
    """Create a machine instance from a configuration file.

    The file is parsed with ``UWConfigParser``. Only the ``initial``/``final``
    flags of each state and the ``transitions`` list of each event are passed
    on to ``from_definition``.

    Args:
        config_file: Path to the configuration file. It is also forwarded to
            the machine constructor as the ``config_file`` keyword.
        **extra_kwargs: Further keyword arguments forwarded to
            ``from_definition``.

    Returns:
        The machine instance produced by ``from_definition``.
    """
    config = UWConfigParser(config_file).parse()
    extra_kwargs["config_file"] = config_file

    definition_states = {
        state_id: {
            "initial": state_value.get("initial", False),
            "final": state_value.get("final", False),
        }
        for state_id, state_value in config["states"].items()
    }
    definition_events = {
        event_name: list(event_value["transitions"])
        for event_name, event_value in config["events"].items()
    }
    definition = {"states": definition_states, "events": definition_events}
    return cls.from_definition(definition, **extra_kwargs)
def start_machine(self) -> None:
    """Start the machine.

    The base implementation does nothing; presumably concrete machines
    override it (TODO confirm).
    """
def reset_structures(self) -> None:
    """Forget every cached structure by rebinding the cache to a new empty dict."""
    fresh_cache: dict = {}
    self._structures = fresh_cache
def on_enter_state(self, source: State, state: State, event: Event) -> None:
    """Print which transition is being made and the event that caused it."""
    message = f"Transitioning from {source} to {state} with event {event}"
    print(message)
def get_structure(self, structure_id: str) -> Structure:
    """Return an Agent for ``structure_id`` configured for the current state.

    The global structure config is merged with the current state's override
    for the same id (list values are concatenated, see ``custom_dict_merge``).
    One Agent per id is cached in ``self._structures``. Every call returns a
    new Agent that reuses the cached one's id, prompt driver and conversation
    memory, with rulesets built from the merged ``ruleset_ids``.

    Raises:
        KeyError: If ``structure_id`` is not under ``config["structures"]``.
    """
    defaults = self.config["structures"][structure_id]
    state_structures = self._current_state_config.get("structures", {})
    overrides = state_structures.get(structure_id, {})
    merged_config = custom_dict_merge(defaults, overrides)

    if structure_id not in self._structures:
        # Initialize Structure with all the expensive setup
        self._structures[structure_id] = Agent(
            id=structure_id,
            conversation_memory=ConversationMemory(),
        )

    # Create a new clone with state-specific stuff
    cached = self._structures[structure_id]
    ruleset_ids = merged_config.get("ruleset_ids", [])
    structure = Agent(
        id=cached.id,
        prompt_driver=cached.prompt_driver,
        conversation_memory=cached.conversation_memory,
        rulesets=list(self._get_structure_rulesets(ruleset_ids)),
    )

    print(f"Structure: {structure_id}")
    all_rules = (rule for ruleset in structure.rulesets for rule in ruleset.rules)
    for rule in all_rules:
        print(f"Rule: {rule.value}")
    return structure
| def _get_structure_rulesets(self, ruleset_ids: list[str]) -> list[Ruleset]: | |
| ruleset_configs = [ | |
| self.config["rulesets"][ruleset_id] for ruleset_id in ruleset_ids | |
| ] | |
| # Convert ruleset configs to Rulesets | |
| return [ | |
| Ruleset( | |
| name=ruleset_config["name"], | |
| rules=[Rule(rule) for rule in ruleset_config["rules"]], | |
| ) | |
| for ruleset_config in ruleset_configs | |
| ] | |
def retrieve_vector_stores(self) -> None:
    """Fetch knowledge bases from Griptape Cloud and store the section KB ids.

    Walks every page of the knowledge-base listing and records, in
    ``self.kb_ids``, a ``name -> knowledge_base_id`` entry for each knowledge
    base whose name contains ``"KB_section"``. The bearer token is read from
    the ``GT_CLOUD_API_KEY`` environment variable.

    Every fetched page is processed, including the last one (whose
    ``next_page`` is ``None``) and therefore also a single-page listing.

    Raises:
        requests.HTTPError: If a response has a 4xx/5xx status.
        ValueError: If a response status is anything other than 200.
        requests.Timeout: If the server does not answer within the timeout.
    """
    kb_url = "https://cloud.griptape.ai/api/knowledge-bases"
    headers = {"Authorization": f"Bearer {os.getenv('GT_CLOUD_API_KEY')}"}
    request_timeout = 30  # seconds; keeps a stalled connection from hanging forever

    all_kbs: dict[str, str] = {}
    page = None  # None requests the first page
    while True:
        params = None if page is None else {"page": page}
        response = requests.get(
            url=kb_url,
            headers=headers,
            params=params,
            timeout=request_timeout,
        )
        response.raise_for_status()
        if response.status_code != 200:
            raise ValueError(response.status_code)

        data = response.json()
        for kb in data["knowledge_bases"]:
            name = kb["name"]
            if "KB_section" in name:
                all_kbs[name] = kb["knowledge_base_id"]

        page = data["pagination"]["next_page"]
        if page is None:
            break

    self.kb_ids = all_kbs
# ALL METHODS RELATING TO THE WORKFLOW AND PIPELINE ARE BELOW THIS LINE
# This is the overarching workflow. Creates a workflow with get_single_question x amount of times.
# NOTE(review): earlier StructureRunTask-based draft of get_questions_workflow.
# Its `def` line was already commented out while its body was still live, which
# left the statements without an enclosing function. The whole draft is now
# commented out; the active CodeExecutionTask-based get_questions_workflow is
# defined further down in this class.
# def get_questions_workflow(self) -> Workflow:
#     workflow = Workflow(id="create_question_workflow")
#     # How many questions still need to be created?
#     for _ in range(self.question_number - len(self.question_list)):
#         task = StructureRunTask(
#             structure_run_driver=LocalStructureRunDriver(
#                 create_structure=self.get_single_question
#             ),
#             child_ids=["end_task"],
#         )
#         # Create X amount of workflows to run for X amount of questions needed.
#         workflow.add_task(task)
#     end_task = CodeExecutionTask(id="end_task", on_run=self.end_workflow)
#     workflow.add_task(end_task)
#     return workflow
def workflow_cet(self, task: CodeExecutionTask) -> BaseArtifact:
    """Run one single-question machine and wrap its outcome in an artifact.

    Args:
        task: The CodeExecutionTask invoking this callback (not used here).

    Returns:
        ``TextArtifact`` with the generated question when it was accepted;
        ``InfoArtifact("Bad KB Range")`` when the rejection reason is
        ``"BAD KB PAGE RANGE"``; otherwise ``InfoArtifact("Question is
        Rejected")``, after the question was added to
        ``self.rejected_questions``.
    """
    machine = SingleQuestion.create_statemachine(
        self.taxonomy, self.kb_ids, self.page_range
    )
    machine.send("start_up")

    if not machine.rejected:
        return TextArtifact(machine.generated_question)
    if machine.reject_reason == "BAD KB PAGE RANGE":
        return InfoArtifact("Bad KB Range")
    self.rejected_questions.append(machine.generated_question)
    return InfoArtifact("Question is Rejected")
def get_questions_workflow(self) -> Workflow:
    """Build the workflow that generates the questions still missing.

    One ``CodeExecutionTask`` running ``workflow_cet`` is added for each
    question still needed (``question_number`` minus the questions already in
    ``question_list``). Each of them lists ``"end_task"`` as its child, and
    that final task runs ``end_workflow``.
    """
    workflow = Workflow(id="create_question_workflow")
    # How many questions still need to be created?
    remaining = self.question_number - len(self.question_list)
    for _ in range(remaining):
        workflow.add_task(
            CodeExecutionTask(
                on_run=self.workflow_cet,
                child_ids=["end_task"],
            )
        )
    workflow.add_task(CodeExecutionTask(id="end_task", on_run=self.end_workflow))
    return workflow
| # Ends the get_questions_workflow. Compiles all workflow outputs into one output. | |
def end_workflow(self, task: CodeExecutionTask) -> ListArtifact:
    """Collect the parent tasks' outputs into a single ``ListArtifact``.

    Outputs whose ``type`` is ``"InfoArtifact"`` are left out. If one of them
    carries the value ``"Bad KB Range"``, the machine is flagged as errored,
    ``error_to_start`` is sent and an empty ``ListArtifact`` is returned
    immediately.

    Args:
        task: The end task; its ``parent_outputs`` hold the question results.
    """
    collected = []
    for artifact in task.parent_outputs.values():
        if artifact.type != "InfoArtifact":
            collected.append(artifact)
        elif artifact.value == "Bad KB Range":
            self.errored = True
            self.send("error_to_start")
            return ListArtifact([])
    return ListArtifact(collected)