from __future__ import annotations

import json
import logging
import os
import random
from abc import abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, cast

import requests
from dotenv import load_dotenv
from griptape.artifacts import BaseArtifact, InfoArtifact, ListArtifact, TextArtifact
from griptape.configs import Defaults
from griptape.configs.drivers import (
    OpenAiDriversConfig,
)
from griptape.drivers import (
    GriptapeCloudVectorStoreDriver,
    LocalStructureRunDriver,
    OpenAiChatPromptDriver,
)
from griptape.engines.rag import RagEngine
from griptape.engines.rag.modules import (
    TextChunksResponseRagModule,
    VectorStoreRetrievalRagModule,
)
from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage
from griptape.events import (
    BaseEvent,
    EventBus,
    EventListener,
    FinishStructureRunEvent,
)
from griptape.memory.structure import ConversationMemory
from griptape.rules import Rule, Ruleset
from griptape.structures import Agent, Workflow
from griptape.tasks import CodeExecutionTask, StructureRunTask, ToolTask
from griptape.tools import RagTool
from statemachine import State, StateMachine
from statemachine.factory import StateMachineMetaclass

from parsers import UWConfigParser
from uw_programmatic.single_question_machine import SingleQuestion

logger = logging.getLogger(__name__)
logging.getLogger("griptape").setLevel(logging.ERROR)

if TYPE_CHECKING:
    from griptape.structures import Structure
    from griptape.tools import BaseTool
    from statemachine.event import Event

load_dotenv()

# Sets max tokens and OpenAI as the driver.
Defaults.drivers_config = OpenAiDriversConfig(
    prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", max_tokens=4096)
)


def custom_dict_merge(dict1: dict, dict2: dict) -> dict:
    """Merges two dicts, concatenating values that are lists in both."""
    result = dict1.copy()
    for key, value in dict2.items():
        if key in result and isinstance(result[key], list) and isinstance(value, list):
            result[key] = result[key] + value
        else:
            result[key] = value
    return result
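
# Illustrative example (hypothetical values): when both dicts hold a list under
# the same key the lists are concatenated; otherwise the value from dict2 wins.
#
#     custom_dict_merge({"ruleset_ids": ["a"], "n": 1}, {"ruleset_ids": ["b"], "n": 2})
#     # -> {"ruleset_ids": ["a", "b"], "n": 2}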


class UWBaseMachine(StateMachine):
    """Base class for a machine.

    Attributes:
        config_file (Path): The path to the configuration file.
        config (dict): The configuration data.
        outputs_to_user (list[str]): Outputs to return to the user.
    """

    def __init__(self, config_file: Path, **kwargs) -> None:
        self.config_parser = UWConfigParser(config_file)
        self.config = self.config_parser.parse()
        self._structures = {}
        self.vector_stores = {}  # Stored here in case they are needed more than once
        self.question_list: list = []
        # Parameters that must come from the user
        self.page_range: tuple = ()
        self.question_number: int = 0
        self.taxonomy: list = []
        # To track when to give up
        self.give_up_count = 0
        self.current_question_count = 0
        # To keep vector stores on track
        self.kb_ids = {}
        self.rejected_questions: list = []
        self.errored: bool = False
        self.state_status: dict[str, bool] = {}
        for key in self.state_transitions:
            self.state_status[key] = False

        def on_event(event: BaseEvent) -> None:
            """Takes in Griptape events from the EventBus and forwards them to the state machine."""
            print(f"Received Griptape event: {json.dumps(event.to_dict(), indent=2)}")
            try:
                self.send(
                    "process_event",
                    event_={"type": "griptape_event", "value": event.to_dict()},
                )
            except Exception as e:
                errormsg = (
                    "Could not send process_event. Check to see if it is defined "
                    f"in the config.yaml. Error: {e}"
                )
                raise ValueError(errormsg) from e

        EventBus.clear_event_listeners()
        EventBus.add_event_listener(
            EventListener(on_event, event_types=[FinishStructureRunEvent]),
        )
        super().__init__()

    def available_events(self) -> list[str]:
        return self.current_state.transitions.unique_events

    @property
    @abstractmethod
    def tools(self) -> dict[str, BaseTool]:
        """Returns the Tools for the machine."""
        ...

    @property
    def _current_state_config(self) -> dict:
        return self.config["states"][self.current_state_value]

    @classmethod
    def from_definition(  # noqa: C901, PLR0912
        cls, definition: dict, **extra_kwargs
    ) -> UWBaseMachine:
        try:
            states_instances = {}
            for state_id, state_kwargs in definition["states"].items():
                # These are the relevant states that need GOAP.
                states_instances[state_id] = State(**state_kwargs, value=state_id)
        except Exception as e:
            errormsg = f"Error in state definition: {e}."
            raise ValueError(errormsg) from e

        events = {}
        state_transitions = {}
        for event_name, transitions in definition["events"].items():
            for transition_data in transitions:
                try:
                    source_name = transition_data["from"]
                    source = states_instances[source_name]
                    target = states_instances[transition_data["to"]]
                    relevance = transition_data.get("relevance", "")
                    if source_name not in state_transitions:
                        state_transitions[source_name] = {event_name: relevance}
                    else:
                        state_transitions[source_name][event_name] = relevance
                except Exception as e:
                    errormsg = f"Error: {e}. Please check your transitions to be sure each transition has a source and a destination."
                    raise ValueError(errormsg) from e

                transition = source.to(
                    target,
                    event=event_name,
                    cond=transition_data.get("cond"),
                    unless=transition_data.get("unless"),
                    on=transition_data.get("on"),
                    internal=transition_data.get("internal"),
                )
                if event_name in events:
                    events[event_name] |= transition
                else:
                    events[event_name] = transition

        # Add an internal self-transition so every non-terminal state can process Griptape events.
        for state_id, state in states_instances.items():
            if state_id not in ("end", "start"):
                transition = state.to(
                    state,
                    event="process_event",
                    on=f"on_event_{state_id}",
                    internal=True,
                )
                if "process_event" in events:
                    events["process_event"] |= transition
                else:
                    events["process_event"] = transition

        attrs_mapper = {
            **extra_kwargs,
            **states_instances,
            **events,
            "state_transitions": state_transitions,
        }
        return cast(
            UWBaseMachine,
            StateMachineMetaclass(cls.__name__, (cls,), attrs_mapper)(**extra_kwargs),
        )

    @classmethod
    def from_config_file(
        cls,
        config_file: Path,
        **extra_kwargs,
    ) -> UWBaseMachine:
        """Creates a StateMachine class from a configuration file."""
        config_parser = UWConfigParser(config_file)
        config = config_parser.parse()
        extra_kwargs["config_file"] = config_file

        definition_states = {
            state_id: {
                "initial": state_value.get("initial", False),
                "final": state_value.get("final", False),
            }
            for state_id, state_value in config["states"].items()
        }
        definition_events = {
            event_name: list(event_value["transitions"])
            for event_name, event_value in config["events"].items()
        }
        definition = {"states": definition_states, "events": definition_events}

        return cls.from_definition(definition, **extra_kwargs)
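
    # A minimal sketch of the config.yaml shape this expects, inferred from the
    # parsing above and the .get() calls in from_definition (state and event
    # names here are hypothetical):
    #
    #     states:
    #       start: {initial: true}
    #       gather_parameters: {}
    #       end: {final: true}
    #     events:
    #       start_machine:
    #         transitions:
    #           - {from: start, to: gather_parameters, relevance: "...", on: "..."}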

    @abstractmethod
    def start_machine(self) -> None:
        """Starts the machine."""
        ...

    def reset_structures(self) -> None:
        """Resets the structures."""
        self._structures = {}

    def on_enter_state(self, source: State, state: State, event: Event) -> None:
        print(f"Transitioning from {source} to {state} with event {event}")

    def get_structure(self, structure_id: str) -> Structure:
        global_structure_config = self.config["structures"][structure_id]
        state_structure_config = self._current_state_config.get("structures", {}).get(
            structure_id, {}
        )
        structure_config = custom_dict_merge(
            global_structure_config, state_structure_config
        )
        if structure_id not in self._structures:
            # Initialize Structure with all the expensive setup
            structure = Agent(
                id=structure_id,
                conversation_memory=ConversationMemory(),
            )
            self._structures[structure_id] = structure

        # Create a new clone with state-specific stuff
        structure = self._structures[structure_id]
        structure = Agent(
            id=structure.id,
            prompt_driver=structure.prompt_driver,
            conversation_memory=structure.conversation_memory,
            rulesets=[
                *self._get_structure_rulesets(structure_config.get("ruleset_ids", [])),
            ],
        )
        print(f"Structure: {structure_id}")
        for ruleset in structure.rulesets:
            for rule in ruleset.rules:
                print(f"Rule: {rule.value}")
        return structure

    def _get_structure_rulesets(self, ruleset_ids: list[str]) -> list[Ruleset]:
        ruleset_configs = [
            self.config["rulesets"][ruleset_id] for ruleset_id in ruleset_ids
        ]
        # Convert ruleset configs to Rulesets
        return [
            Ruleset(
                name=ruleset_config["name"],
                rules=[Rule(rule) for rule in ruleset_config["rules"]],
            )
            for ruleset_config in ruleset_configs
        ]
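
    # Illustrative config entry consumed above (the key and field names follow
    # the lookups in this method; the concrete values are hypothetical):
    #
    #     rulesets:
    #       summarizer_ruleset:
    #         name: Summarizer
    #         rules:
    #           - "Respond only with a summary of the input."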

    def retrieve_vector_stores(self) -> None:
        base_url = "https://cloud.griptape.ai/api"
        kb_url = f"{base_url}/knowledge-bases"
        headers = {"Authorization": f"Bearer {os.getenv('GT_CLOUD_API_KEY')}"}
        response = requests.get(url=kb_url, headers=headers)
        response.raise_for_status()
        all_kbs = {}
        data = response.json()
        while True:
            # Collect the knowledge bases on the current page.
            for kb in data["knowledge_bases"]:
                name = kb["name"]
                kb_id = kb["knowledge_base_id"]
                if "KB_section" in name:
                    all_kbs[name] = kb_id
            next_page = data["pagination"]["next_page"]
            if next_page is None:
                break
            response = requests.get(url=f"{kb_url}?page={next_page}", headers=headers)
            response.raise_for_status()
            data = response.json()
        self.kb_ids = all_kbs
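
    # Shape of the Griptape Cloud response consumed above, reduced to only the
    # fields this method actually reads (values are hypothetical):
    #
    #     {
    #       "knowledge_bases": [{"name": "KB_section_1", "knowledge_base_id": "..."}],
    #       "pagination": {"next_page": 2}
    #     }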

    # ALL METHODS RELATING TO THE WORKFLOW AND PIPELINE ARE BELOW THIS LINE

    # This is the overarching workflow. Creates a workflow with get_single_question
    # x amount of times. Kept commented out for reference; the CodeExecutionTask
    # version of get_questions_workflow below is the one in use.
    # def get_questions_workflow(self) -> Workflow:
    #     workflow = Workflow(id="create_question_workflow")
    #     # How many questions still need to be created?
    #     for _ in range(self.question_number - len(self.question_list)):
    #         task = StructureRunTask(
    #             structure_run_driver=LocalStructureRunDriver(
    #                 create_structure=self.get_single_question
    #             ),
    #             child_ids=["end_task"],
    #         )
    #         # Create X amount of workflows to run for X amount of questions needed.
    #         workflow.add_task(task)
    #     end_task = CodeExecutionTask(id="end_task", on_run=self.end_workflow)
    #     workflow.add_task(end_task)
    #     return workflow

    def workflow_cet(self, task: CodeExecutionTask) -> BaseArtifact:
        question_machine = SingleQuestion.create_statemachine(
            self.taxonomy, self.kb_ids, self.page_range
        )
        question_machine.send("start_up")
        if question_machine.rejected:
            if question_machine.reject_reason == "BAD KB PAGE RANGE":
                return InfoArtifact("Bad KB Range")
            self.rejected_questions.append(question_machine.generated_question)
            return InfoArtifact("Question is Rejected")
        return TextArtifact(question_machine.generated_question)

    def get_questions_workflow(self) -> Workflow:
        workflow = Workflow(id="create_question_workflow")
        # How many questions still need to be created?
        for _ in range(self.question_number - len(self.question_list)):
            task = CodeExecutionTask(
                on_run=self.workflow_cet,
                child_ids=["end_task"],
            )
            # Create one task per question still needed.
            workflow.add_task(task)
        end_task = CodeExecutionTask(id="end_task", on_run=self.end_workflow)
        workflow.add_task(end_task)
        return workflow

    # Ends the get_questions_workflow. Compiles all workflow outputs into one output.
    def end_workflow(self, task: CodeExecutionTask) -> ListArtifact:
        parent_outputs = task.parent_outputs
        questions = []
        for outputs in parent_outputs.values():
            if outputs.type == "InfoArtifact":
                if outputs.value == "Bad KB Range":
                    self.errored = True
                    self.send("error_to_start")
                    return ListArtifact([])
                continue
            questions.append(outputs)
        return ListArtifact(questions)
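

# A minimal usage sketch, assuming a concrete subclass that implements the
# abstract members and a config.yaml defining its states and events
# (MyQuestionMachine is hypothetical):
#
#     machine = MyQuestionMachine.from_config_file(Path("config.yaml"))
#     machine.start_machine()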