Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| import random | |
| from abc import abstractmethod | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING, cast | |
| import requests | |
| from dotenv import load_dotenv | |
| from griptape.artifacts import ListArtifact, TextArtifact | |
| from griptape.configs import Defaults | |
| from griptape.configs.drivers import ( | |
| OpenAiDriversConfig, | |
| ) | |
| from griptape.drivers import ( | |
| GriptapeCloudVectorStoreDriver, | |
| LocalStructureRunDriver, | |
| OpenAiChatPromptDriver, | |
| ) | |
| from griptape.engines.rag import RagEngine | |
| from griptape.engines.rag.modules import ( | |
| TextChunksResponseRagModule, | |
| VectorStoreRetrievalRagModule, | |
| ) | |
| from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage | |
| from griptape.events import ( | |
| BaseEvent, | |
| EventBus, | |
| EventListener, | |
| FinishStructureRunEvent, | |
| ) | |
| from griptape.memory.structure import ConversationMemory | |
| from griptape.rules import Rule, Ruleset | |
| from griptape.structures import Agent, Workflow | |
| from griptape.tasks import CodeExecutionTask, StructureRunTask, ToolTask | |
| from griptape.tools import RagTool | |
| from statemachine import State, StateMachine | |
| from statemachine.factory import StateMachineMetaclass | |
| from griptape_statemachine.parsers.uw_config_parser import UWConfigParser | |
# Module-level logger for this file; the "griptape" library logger is set to
# ERROR so only its errors are emitted.
logger = logging.getLogger(__name__)
logging.getLogger("griptape").setLevel(logging.ERROR)

# Names used only in type annotations (annotations are lazy strings because of
# `from __future__ import annotations`), so they are not imported at runtime.
if TYPE_CHECKING:
    from griptape.structures import Structure
    from griptape.tools import BaseTool
    from statemachine.event import Event

# Load variables from a .env file into the process environment before any
# driver is constructed below.
# NOTE(review): presumably needed so API keys (OpenAI credentials,
# GT_CLOUD_API_KEY read via os.getenv) can come from .env — confirm.
load_dotenv()

# Process-wide griptape default drivers config: an OpenAI chat prompt driver
# using model "gpt-4o" with max_tokens=4096.
Defaults.drivers_config = OpenAiDriversConfig(
    prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", max_tokens=4096)
)
def custom_dict_merge(dict1: dict, dict2: dict) -> dict:
    """Merge ``dict2`` over ``dict1`` without mutating either input.

    For a key present in both dicts whose values are both lists, the result
    holds a new list ``dict1[key] + dict2[key]``. For every other key in
    ``dict2`` its value replaces (or adds to) the one from ``dict1``. The merge
    is shallow: nested dicts are replaced, not merged.

    Args:
        dict1: Base mapping.
        dict2: Overriding mapping.

    Returns:
        A shallow copy of ``dict1`` with ``dict2`` merged in.
    """
    merged = dict1.copy()
    for name, incoming in dict2.items():
        existing = merged.get(name)
        # A missing key yields None, which is never a list, so it falls
        # through to plain assignment exactly like an absent key.
        concatenate = isinstance(existing, list) and isinstance(incoming, list)
        merged[name] = existing + incoming if concatenate else incoming
    return merged
class UWBaseMachine(StateMachine):
    """Base class for a configuration-driven question-generation machine.

    Concrete machine classes are built at runtime from a configuration file via
    ``from_config_file`` / ``from_definition``, which inject the states, the
    events and a ``state_transitions`` class attribute.

    Attributes:
        config_parser (UWConfigParser): Parser for the configuration file.
        config (dict): Parsed configuration; the ``states``, ``events``,
            ``structures``, ``rulesets`` and ``prompts`` sections are read.
        vector_stores (dict): Cache for vector store drivers needing reuse.
        question_list (list): Questions generated so far.
        page_range (tuple): ``(first_page, last_page)`` bounds used to select
            a knowledge base.
        question_number (int): Total number of questions requested.
        taxonomy (list): Taxonomy levels to pick from (e.g. "Knowledge",
            "Comprehension", "Application").
        give_up_count (int): Counter initialised to 0 (TODO confirm use).
        current_question_count (int): Counter initialised to 0
            (TODO confirm use).
        state_status (dict[str, bool]): One flag per state that is the source
            of a configured transition, initialised to False.
    """

    def __init__(self, config_file: Path, **kwargs) -> None:
        self.config_parser = UWConfigParser(config_file)
        self.config = self.config_parser.parse()
        self._structures = {}
        self.vector_stores = {}  # Store here in case needs multiple uses
        self.question_list: list = []
        # For the parameters necessary from the user
        self.page_range: tuple = ()
        self.question_number: int = 0
        self.taxonomy: list = []
        self.give_up_count = 0
        self.current_question_count = 0
        # `state_transitions` is the class attribute injected by from_definition.
        self.state_status: dict[str, bool] = {
            key: False for key in self.state_transitions
        }

        def on_event(event: BaseEvent) -> None:
            """Forward a griptape event to this machine as ``process_event``."""
            print(f"Received Griptape event: {json.dumps(event.to_dict(), indent=2)}")
            try:
                self.send(
                    "process_event",
                    event_={"type": "griptape_event", "value": event.to_dict()},
                )
            except Exception as e:
                errormsg = f"Would not allow process_event to be sent. Check to see if it is defined in the config.yaml. Error:{e}"
                raise ValueError(errormsg) from e

        # Clear every listener on the global EventBus (including those left by
        # previously created machines) before registering this machine's.
        EventBus.clear_event_listeners()
        EventBus.add_event_listener(
            EventListener(on_event, event_types=[FinishStructureRunEvent]),
        )
        super().__init__()

    # NOTE(review): kept as a plain method; upstream may have declared this a
    # @property — confirm against callers.
    def available_events(self) -> list[str]:
        """Return the unique event names available from the current state."""
        return self.current_state.transitions.unique_events

    @abstractmethod
    def tools(self) -> dict[str, BaseTool]:
        """Returns the Tools for the machine."""
        ...

    @property
    def _current_state_config(self) -> dict:
        """Configuration dict of the current state (``config["states"][...]``)."""
        return self.config["states"][self.current_state_value]

    @classmethod
    def from_definition(  # noqa: C901, PLR0912
        cls, definition: dict, **extra_kwargs
    ) -> UWBaseMachine:
        """Build a machine subclass from a states/events definition and instantiate it.

        Every state except ``start`` and ``end`` also gets an internal
        self-transition on ``process_event`` that calls ``on_event_<state_id>``.

        Args:
            definition: ``{"states": {id: State kwargs}, "events": {name:
                [transition dicts with "from", "to" and optional "relevance",
                "cond", "unless", "on", "internal"]}}``.
            **extra_kwargs: Set as class attributes and passed to ``__init__``.

        Returns:
            An instance of the generated subclass.

        Raises:
            ValueError: If a state or a transition definition is invalid.
        """
        try:
            states_instances = {}
            for state_id, state_kwargs in definition["states"].items():
                # These are the relevant states that need GOAP.
                states_instances[state_id] = State(**state_kwargs, value=state_id)
        except Exception as e:
            errormsg = f"""Error in state definition: {e}.
                """
            raise ValueError(errormsg) from e

        events = {}
        # source state id -> {event name: relevance text}
        state_transitions: dict[str, dict[str, str]] = {}
        for event_name, transitions in definition["events"].items():
            for transition_data in transitions:
                try:
                    source_name = transition_data["from"]
                    source = states_instances[source_name]
                    target = states_instances[transition_data["to"]]
                    relevance = transition_data.get("relevance", "")
                    state_transitions.setdefault(source_name, {})[event_name] = relevance
                except Exception as e:
                    errormsg = f"Error:{e}. Please check your transitions to be sure each transition has a source and destination."
                    raise ValueError(errormsg) from e

                transition = source.to(
                    target,
                    event=event_name,
                    cond=transition_data.get("cond"),
                    unless=transition_data.get("unless"),
                    on=transition_data.get("on"),
                    internal=transition_data.get("internal"),
                )
                if event_name in events:
                    events[event_name] |= transition
                else:
                    events[event_name] = transition

        # Internal self-loops route griptape events to per-state handlers.
        for state_id, state in states_instances.items():
            if state_id in ("end", "start"):
                continue
            transition = state.to(
                state,
                event="process_event",
                on=f"on_event_{state_id}",
                internal=True,
            )
            if "process_event" in events:
                events["process_event"] |= transition
            else:
                events["process_event"] = transition

        attrs_mapper = {
            **extra_kwargs,
            **states_instances,
            **events,
            "state_transitions": state_transitions,
        }
        return cast(
            UWBaseMachine,
            StateMachineMetaclass(cls.__name__, (cls,), attrs_mapper)(**extra_kwargs),
        )

    @classmethod
    def from_config_file(
        cls,
        config_file: Path,
        **extra_kwargs,
    ) -> UWBaseMachine:
        """Creates a StateMachine class from a configuration file.

        Reads ``initial``/``final`` flags from each entry of ``states`` and the
        ``transitions`` list from each entry of ``events``, then delegates to
        ``from_definition``. ``config_file`` is forwarded to ``__init__``.
        """
        config_parser = UWConfigParser(config_file)
        config = config_parser.parse()
        extra_kwargs["config_file"] = config_file
        definition_states = {
            state_id: {
                "initial": state_value.get("initial", False),
                "final": state_value.get("final", False),
            }
            for state_id, state_value in config["states"].items()
        }
        definition_events = {
            event_name: list(event_value["transitions"])
            for event_name, event_value in config["events"].items()
        }
        definition = {"states": definition_states, "events": definition_events}
        return cls.from_definition(definition, **extra_kwargs)

    @abstractmethod
    def start_machine(self) -> None:
        """Starts the machine."""
        ...

    def reset_structures(self) -> None:
        """Resets the structures."""
        self._structures = {}

    def on_enter_state(self, source: State, state: State, event: Event) -> None:
        """Log every state transition (parameters are injected by name)."""
        print(f"Transitioning from {source} to {state} with event {event}")

    def get_structure(self, structure_id: str) -> Structure:
        """Return an Agent for ``structure_id`` configured for the current state.

        The global ``config["structures"][structure_id]`` entry is merged with
        the current state's override (lists are concatenated, other values are
        replaced). A cached base Agent keeps the conversation memory across
        calls; a fresh Agent sharing that memory and prompt driver is returned
        with the merged rulesets.

        Raises:
            KeyError: If ``structure_id`` or a referenced ruleset is unknown.
        """
        global_structure_config = self.config["structures"][structure_id]
        state_structure_config = self._current_state_config.get("structures", {}).get(
            structure_id, {}
        )
        structure_config = custom_dict_merge(
            global_structure_config, state_structure_config
        )
        if structure_id not in self._structures:
            # Initialize Structure with all the expensive setup
            self._structures[structure_id] = Agent(
                id=structure_id,
                conversation_memory=ConversationMemory(),
            )
        # Create a new clone with state-specific stuff
        base_structure = self._structures[structure_id]
        structure = Agent(
            id=base_structure.id,
            prompt_driver=base_structure.prompt_driver,
            conversation_memory=base_structure.conversation_memory,
            rulesets=[
                *self._get_structure_rulesets(structure_config.get("ruleset_ids", [])),
            ],
        )
        print(f"Structure: {structure_id}")
        for ruleset in structure.rulesets:
            for rule in ruleset.rules:
                print(f"Rule: {rule.value}")
        return structure

    def _get_structure_rulesets(self, ruleset_ids: list[str]) -> list[Ruleset]:
        """Convert the configured rulesets named by ``ruleset_ids`` to Rulesets."""
        return [
            Ruleset(
                name=ruleset_config["name"],
                rules=[Rule(rule) for rule in ruleset_config["rules"]],
            )
            for ruleset_config in (
                self.config["rulesets"][ruleset_id] for ruleset_id in ruleset_ids
            )
        ]

    def get_prompt_by_structure(self, structure_id: str) -> str | None:
        """Return the prompt text configured for ``structure_id``.

        A ``prompt_id`` set on the current state's structure config overrides
        the global one, matching the override order used by ``get_structure``.

        Returns:
            The prompt text, or None if the structure or its ``prompt_id`` is
            not configured.

        Raises:
            KeyError: If the resolved ``prompt_id`` is not in ``config["prompts"]``.
        """
        try:
            state_structure_config = self._current_state_config.get(
                "structures", {}
            ).get(structure_id, {})
            global_structure_config = self.config["structures"][structure_id]
        except KeyError:
            return None
        prompt_id = state_structure_config.get(
            "prompt_id", global_structure_config.get("prompt_id")
        )
        if prompt_id is None:
            return None
        return self.config["prompts"][prompt_id]["prompt"]

    def get_prompt_by_id(self, prompt_id: str) -> str | None:
        """Return ``config["prompts"][prompt_id]["prompt"]`` or None if unknown."""
        prompt_config = self.config["prompts"]
        if prompt_id in prompt_config:
            return prompt_config[prompt_id]["prompt"]
        return None

    # ALL METHODS RELATING TO THE WORKFLOW AND PIPELINE

    @staticmethod
    def _parse_json_output(text: str) -> dict:
        """Parse a JSON object from LLM output.

        Tries the raw text first; if that fails, parses the span from the first
        ``{`` to the last ``}``, which tolerates Markdown code fences
        (```json ... ```) and short preambles around the object.

        Raises:
            json.JSONDecodeError: If no parseable JSON object is found.
        """
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            start, end = text.find("{"), text.rfind("}")
            if start == -1 or end <= start:
                raise
            return json.loads(text[start : end + 1])

    def end_workflow(self, task: CodeExecutionTask) -> ListArtifact:
        """Collect every parent output that is not an InfoArtifact into a list."""
        return ListArtifact(
            [
                outputs
                for outputs in task.parent_outputs.values()
                if outputs.type != "InfoArtifact"
            ]
        )

    def get_questions_workflow(self) -> Workflow:
        """Build a Workflow that generates the questions still missing.

        One ``get_single_question`` sub-workflow runs per missing question; all
        of them feed ``end_task``, which gathers the results.
        """
        workflow = Workflow(id="create_question_workflow")
        # How many questions still need to be created (none if already enough).
        for _ in range(self.question_number - len(self.question_list)):
            task = StructureRunTask(
                structure_run_driver=LocalStructureRunDriver(
                    create_structure=self.get_single_question
                ),
                child_ids=["end_task"],
            )
            workflow.add_task(task)
        end_task = CodeExecutionTask(id="end_task", on_run=self.end_workflow)
        workflow.add_task(end_task)
        return workflow

    def single_question_last_task(self, task: CodeExecutionTask) -> TextArtifact:
        """Assemble the final question record from the sub-workflow outputs.

        Expects ``task.input`` to be ``"<pages>, <taxonomy>"`` and parent
        outputs ``wrong_answers`` (one answer per line) and ``get_question``
        (a JSON object with "Question" and "Answer").
        """
        parent_outputs = task.parent_outputs
        # Keep one wrong answer per non-blank line.
        wrong_answers = [
            line.strip()
            for line in parent_outputs["wrong_answers"].value.split("\n")
            if line.strip()
        ]
        question_and_answer = self._parse_json_output(
            parent_outputs["get_question"].value
        )
        inputs = task.input.value.split(",")
        question = {
            "Question": question_and_answer["Question"],
            "Answer": question_and_answer["Answer"],
            "Wrong Answers": wrong_answers,
            # The input is built as "<pages>, <taxonomy>"; strip the separator space.
            "Page": inputs[0].strip(),
            "Taxonomy": inputs[1].strip(),
        }
        return TextArtifact(question)

    def get_question_for_wrong_answers(self, task: CodeExecutionTask) -> TextArtifact:
        """Extract the "Question" field from the ``get_question`` JSON output."""
        question = self._parse_json_output(task.parent_outputs["get_question"].value)
        return TextArtifact(question["Question"])

    def get_separated_answer_for_wrong_answers(
        self, task: CodeExecutionTask
    ) -> TextArtifact:
        """Extract the "Answer" field from the ``get_question`` JSON output."""
        raw_answer = task.parent_outputs["get_question"].value
        logger.debug("get_question output: %s", raw_answer)
        return TextArtifact(self._parse_json_output(raw_answer)["Answer"])

    def make_rag_structure(
        self, vector_store: GriptapeCloudVectorStoreDriver
    ) -> Structure:
        """Return an Agent whose single task queries ``vector_store`` via a RagTool.

        Raises:
            ValueError: If ``vector_store`` is falsy.
        """
        if not vector_store:
            errormsg = "No Vector Store"
            raise ValueError(errormsg)
        tool = self.build_rag_tool(self.build_rag_engine(vector_store))
        use_rag_task = ToolTask(tool=tool)
        return Agent(tasks=[use_rag_task])

    def get_single_question(self) -> Workflow:
        """Build the Workflow that generates one multiple-choice question.

        Task graph::

            information_task -> get_question -> get_only_question ----+
                                             -> get_separated_answer -+-> wrong_answers
            information_task ---------------------------------------------^
            wrong_answers + get_question -> compile_task

        A random taxonomy level from ``self.taxonomy`` selects the question
        prompt, and a random knowledge base within ``self.page_range`` supplies
        the source material.
        """
        question_generator = Workflow(id="single_question")
        taxonomy = random.choice(self.taxonomy)
        taxonomyprompt = {
            "Knowledge": "Generate a quiz question based ONLY on this information: {{parent_outputs['information_task']}}, then write the answer to the question. The interrogative verb for the question should be randomly chosen from: 'define', 'list', 'state', 'identify','label'.",
            "Comprehension": "Generate a quiz question based ONLY on this information: {{parent_outputs['information_task']}}, then write the answer to the question. The interrogative verb for the question should be randomly chosen from: 'explain', 'predict', 'interpret', 'infer', 'summarize', 'convert','give an example of x'.",
            "Application": "Generate a quiz question based ONLY on this information: {{parent_outputs['information_task']}}, then write the answer to the question. The structure of the question should be randomly chosen from: 'How could x be used to y?', 'How would you show/make use of/modify/demonstrate/solve/apply x to conditions y?'",
        }
        pages, driver = self.get_vector_store_id_from_page()
        get_information = StructureRunTask(
            id="information_task",
            input="What is the information in KB?",
            structure_run_driver=LocalStructureRunDriver(
                create_structure=lambda: self.make_rag_structure(driver)
            ),
            child_ids=["get_question"],
        )
        # Get KBs and select it, assign it to the structure or create the structure right here.
        # Rules for subject matter expert: return only a json with question and answer as keys.
        generate_q_task = StructureRunTask(
            id="get_question",
            input=taxonomyprompt[taxonomy],
            structure_run_driver=LocalStructureRunDriver(
                create_structure=lambda: self.get_structure("subject_matter_expert")
            ),
            parent_ids=["information_task"],
        )
        get_question_code_task = CodeExecutionTask(
            id="get_only_question",
            on_run=self.get_question_for_wrong_answers,
            parent_ids=["get_question"],
            child_ids=["wrong_answers"],
        )
        get_separated_answer_code_task = CodeExecutionTask(
            id="get_separated_answer",
            on_run=self.get_separated_answer_for_wrong_answers,
            parent_ids=["get_question"],
            child_ids=["wrong_answers"],
        )
        # The template keys must match the parent task ids listed in parent_ids.
        generate_wrong_answers = StructureRunTask(
            id="wrong_answers",
            input="""Write and return three incorrect answers for this question: {{parent_outputs['get_only_question']}}. The correct answer to the question is: {{parent_outputs['get_separated_answer']}}, and incorrect answers should have similar sentence structure to the correct answer. Write the incorrect answers from this information: {{parent_outputs['information_task']}}""",
            structure_run_driver=LocalStructureRunDriver(
                create_structure=lambda: self.get_structure("wrong_answers_generator")
            ),
            parent_ids=["get_only_question", "get_separated_answer", "information_task"],
        )
        compile_task = CodeExecutionTask(
            id="compile_task",
            input=f"{pages}, {taxonomy}",
            on_run=self.single_question_last_task,
            parent_ids=["wrong_answers", "get_question"],
        )
        question_generator.add_tasks(
            get_information,
            generate_q_task,
            get_question_code_task,
            get_separated_answer_code_task,
            generate_wrong_answers,
            compile_task,
        )
        return question_generator

    def get_vector_store_id_from_page(
        self,
    ) -> tuple[str, GriptapeCloudVectorStoreDriver]:
        """Pick a random Griptape Cloud knowledge base inside ``self.page_range``.

        Only knowledge bases whose name contains "KB_section" are considered;
        their page span is parsed from the name.

        Returns:
            ``("<start>-<end>", driver)`` for the chosen knowledge base.

        Raises:
            requests.HTTPError: If the API returns an error status.
            ValueError: On an unexpected non-200 status, or if no knowledge
                base falls inside ``self.page_range``.
        """
        base_url = "https://cloud.griptape.ai/api"
        kb_url = f"{base_url}/knowledge-bases"
        headers = {"Authorization": f"Bearer {os.getenv('GT_CLOUD_API_KEY')}"}
        # TODO: This needs to change when I have my own bucket. Right now, I'm doing the 10 most recently made KBs
        # A timeout keeps a stalled API call from hanging the machine forever.
        response = requests.get(url=kb_url, headers=headers, timeout=30)
        response.raise_for_status()
        if response.status_code != 200:
            raise ValueError(response.status_code)

        data = response.json()
        possible_kbs: dict[str, str] = {}
        for kb in data["knowledge_bases"]:
            name = kb["name"]
            if "KB_section" not in name:
                continue
            # Assumes names like "KB_section_p<start>-p<end>" — TODO confirm.
            page_nums = name.split("p")[1:]
            start_page = int(page_nums[0].split("-")[0])
            end_page = int(page_nums[1])
            if end_page <= self.page_range[1] and start_page >= self.page_range[0]:
                possible_kbs[kb["knowledge_base_id"]] = f"{start_page}-{end_page}"

        if not possible_kbs:
            errormsg = f"No knowledge base found within page range {self.page_range}."
            raise ValueError(errormsg)
        kb_id = random.choice(list(possible_kbs))
        page_value = possible_kbs[kb_id]  # TODO: This won't help at all actually
        return page_value, GriptapeCloudVectorStoreDriver(
            api_key=os.getenv("GT_CLOUD_API_KEY", ""),
            knowledge_base_id=kb_id,
        )

    def get_taxonomy_vs(self) -> GriptapeCloudVectorStoreDriver:
        """Return a driver for the fixed (hard-coded) taxonomy knowledge base."""
        return GriptapeCloudVectorStoreDriver(
            api_key=os.getenv("GT_CLOUD_API_KEY", ""),
            knowledge_base_id="2c3a6f19-51a8-43c3-8445-c7fbe06bf460",
        )

    def build_rag_engine(
        self, vector_store_driver: GriptapeCloudVectorStoreDriver
    ) -> RagEngine:
        """Build a RagEngine that retrieves from ``vector_store_driver`` and returns raw text chunks."""
        return RagEngine(
            retrieval_stage=RetrievalRagStage(
                retrieval_modules=[
                    VectorStoreRetrievalRagModule(
                        vector_store_driver=vector_store_driver,
                    )
                ],
            ),
            response_stage=ResponseRagStage(
                response_modules=[TextChunksResponseRagModule()]
            ),
        )

    def build_rag_tool(self, engine: RagEngine) -> RagTool:
        """Wrap ``engine`` in a RagTool described as textbook context."""
        return RagTool(
            description="Contains information about the textbook. Use it ONLY for context.",
            rag_engine=engine,
        )