from __future__ import annotations

import json
import logging
import os
import random
from abc import abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, cast

import requests
from dotenv import load_dotenv
from griptape.artifacts import ListArtifact, TextArtifact
from griptape.configs import Defaults
from griptape.configs.drivers import (
    OpenAiDriversConfig,
)
from griptape.drivers import (
    GriptapeCloudVectorStoreDriver,
    LocalStructureRunDriver,
    OpenAiChatPromptDriver,
)
from griptape.engines.rag import RagEngine
from griptape.engines.rag.modules import (
    TextChunksResponseRagModule,
    VectorStoreRetrievalRagModule,
)
from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage
from griptape.events import (
    BaseEvent,
    EventBus,
    EventListener,
    FinishStructureRunEvent,
)
from griptape.memory.structure import ConversationMemory
from griptape.rules import Rule, Ruleset
from griptape.structures import Agent, Workflow
from griptape.tasks import CodeExecutionTask, StructureRunTask, ToolTask
from griptape.tools import RagTool
from statemachine import State, StateMachine
from statemachine.factory import StateMachineMetaclass

from griptape_statemachine.parsers.uw_config_parser import UWConfigParser

if TYPE_CHECKING:
    from griptape.structures import Structure
    from griptape.tools import BaseTool
    from statemachine.event import Event

logger = logging.getLogger(__name__)
logging.getLogger("griptape").setLevel(logging.ERROR)

load_dotenv()

Defaults.drivers_config = OpenAiDriversConfig(
    prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", max_tokens=4096)
)


def custom_dict_merge(dict1: dict, dict2: dict) -> dict:
    """Merge two dicts, concatenating values when both sides hold lists."""
    result = dict1.copy()
    for key, value in dict2.items():
        if key in result and isinstance(result[key], list) and isinstance(value, list):
            result[key] = result[key] + value
        else:
            result[key] = value
    return result
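
# For example (illustrative values, not from any real config):
#   custom_dict_merge({"ruleset_ids": ["a"], "x": 1}, {"ruleset_ids": ["b"], "x": 2})
#   -> {"ruleset_ids": ["a", "b"], "x": 2}
# List values are concatenated (so state-level config extends global config),
# while any other value from dict2 simply overwrites dict1's.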


class UWBaseMachine(StateMachine):
    """Base class for a machine.

    Attributes:
        config_file (Path): The path to the configuration file.
        config (dict): The configuration data.
        outputs_to_user (list[str]): Outputs to return to the user.
    """

    def __init__(self, config_file: Path, **kwargs) -> None:
        self.config_parser = UWConfigParser(config_file)
        self.config = self.config_parser.parse()
        self._structures = {}
        self.vector_stores = {}  # Cached here in case they are needed more than once
        self.question_list: list = []
        # Parameters supplied by the user
        self.page_range: tuple = ()
        self.question_number: int = 0
        self.taxonomy: list = []
        self.give_up_count = 0
        self.current_question_count = 0
        self.state_status: dict[str, bool] = {}
        for key in self.state_transitions:
            self.state_status[key] = False

        def on_event(event: BaseEvent) -> None:
            """Receives Griptape events from the EventBus and forwards them to the machine."""
            print(f"Received Griptape event: {json.dumps(event.to_dict(), indent=2)}")
            try:
                self.send(
                    "process_event",
                    event_={"type": "griptape_event", "value": event.to_dict()},
                )
            except Exception as e:
                errormsg = (
                    "Could not send process_event. Check that it is defined "
                    f"in the config.yaml. Error: {e}"
                )
                raise ValueError(errormsg) from e

        EventBus.clear_event_listeners()
        EventBus.add_event_listener(
            EventListener(on_event, event_types=[FinishStructureRunEvent]),
        )
        super().__init__()

    def available_events(self) -> list[str]:
        return self.current_state.transitions.unique_events

    @abstractmethod
    def tools(self) -> dict[str, BaseTool]:
        """Returns the Tools for the machine."""
        ...

    @property
    def _current_state_config(self) -> dict:
        return self.config["states"][self.current_state_value]

    @classmethod
    def from_definition(  # noqa: C901, PLR0912
        cls, definition: dict, **extra_kwargs
    ) -> UWBaseMachine:
        try:
            states_instances = {}
            for state_id, state_kwargs in definition["states"].items():
                # These are the relevant states that need GOAP.
                states_instances[state_id] = State(**state_kwargs, value=state_id)
        except Exception as e:
            errormsg = f"Error in state definition: {e}."
            raise ValueError(errormsg) from e

        events = {}
        state_transitions = {}
        for event_name, transitions in definition["events"].items():
            for transition_data in transitions:
                try:
                    source_name = transition_data["from"]
                    source = states_instances[source_name]
                    target = states_instances[transition_data["to"]]
                    relevance = transition_data.get("relevance", "")
                    if source_name not in state_transitions:
                        state_transitions[source_name] = {event_name: relevance}
                    else:
                        state_transitions[source_name][event_name] = relevance
                except Exception as e:
                    errormsg = (
                        f"Error: {e}. Please check that each transition has "
                        "a source and a destination."
                    )
                    raise ValueError(errormsg) from e

                transition = source.to(
                    target,
                    event=event_name,
                    cond=transition_data.get("cond"),
                    unless=transition_data.get("unless"),
                    on=transition_data.get("on"),
                    internal=transition_data.get("internal"),
                )
                if event_name in events:
                    events[event_name] |= transition
                else:
                    events[event_name] = transition

        # Every non-terminal state also gets an internal self-transition so
        # Griptape events can be processed without leaving the state.
        for state_id, state in states_instances.items():
            if state_id not in ("end", "start"):
                transition = state.to(
                    state,
                    event="process_event",
                    on=f"on_event_{state_id}",
                    internal=True,
                )
                if "process_event" in events:
                    events["process_event"] |= transition
                else:
                    events["process_event"] = transition

        attrs_mapper = {
            **extra_kwargs,
            **states_instances,
            **events,
            "state_transitions": state_transitions,
        }
        return cast(
            UWBaseMachine,
            StateMachineMetaclass(cls.__name__, (cls,), attrs_mapper)(**extra_kwargs),
        )

    @classmethod
    def from_config_file(
        cls,
        config_file: Path,
        **extra_kwargs,
    ) -> UWBaseMachine:
        """Creates a StateMachine class from a configuration file."""
        config_parser = UWConfigParser(config_file)
        config = config_parser.parse()
        extra_kwargs["config_file"] = config_file

        definition_states = {
            state_id: {
                "initial": state_value.get("initial", False),
                "final": state_value.get("final", False),
            }
            for state_id, state_value in config["states"].items()
        }
        definition_events = {
            event_name: list(event_value["transitions"])
            for event_name, event_value in config["events"].items()
        }
        definition = {"states": definition_states, "events": definition_events}

        return cls.from_definition(definition, **extra_kwargs)
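
    # A minimal usage sketch (the config path is hypothetical; any YAML file
    # accepted by UWConfigParser with "states", "events", "structures",
    # "rulesets", and "prompts" sections would do):
    #
    #   machine = UWBaseMachine.from_config_file(Path("config.yaml"))
    #   machine.start_machine()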

    @abstractmethod
    def start_machine(self) -> None:
        """Starts the machine."""
        ...

    def reset_structures(self) -> None:
        """Resets the cached structures."""
        self._structures = {}

    def on_enter_state(self, source: State, state: State, event: Event) -> None:
        print(f"Transitioning from {source} to {state} with event {event}")

    def get_structure(self, structure_id: str) -> Structure:
        global_structure_config = self.config["structures"][structure_id]
        state_structure_config = self._current_state_config.get("structures", {}).get(
            structure_id, {}
        )
        structure_config = custom_dict_merge(
            global_structure_config, state_structure_config
        )
        if structure_id not in self._structures:
            # Initialize the Structure once with all the expensive setup.
            structure = Agent(
                id=structure_id,
                conversation_memory=ConversationMemory(),
            )
            self._structures[structure_id] = structure

        # Create a new clone with the state-specific rulesets applied.
        structure = self._structures[structure_id]
        structure = Agent(
            id=structure.id,
            prompt_driver=structure.prompt_driver,
            conversation_memory=structure.conversation_memory,
            rulesets=[
                *self._get_structure_rulesets(structure_config.get("ruleset_ids", [])),
            ],
        )
        print(f"Structure: {structure_id}")
        for ruleset in structure.rulesets:
            for rule in ruleset.rules:
                print(f"Rule: {rule.value}")
        return structure

    def _get_structure_rulesets(self, ruleset_ids: list[str]) -> list[Ruleset]:
        ruleset_configs = [
            self.config["rulesets"][ruleset_id] for ruleset_id in ruleset_ids
        ]
        # Convert ruleset configs to Rulesets
        return [
            Ruleset(
                name=ruleset_config["name"],
                rules=[Rule(rule) for rule in ruleset_config["rules"]],
            )
            for ruleset_config in ruleset_configs
        ]

    def get_prompt_by_structure(self, structure_id: str) -> str | None:
        try:
            state_structure_config = self._current_state_config.get(
                "structures", {}
            ).get(structure_id, {})
            global_structure_config = self.config["structures"][structure_id]
        except KeyError:
            return None
        # Note: the global prompt_id takes precedence over the state-level one.
        if "prompt_id" in global_structure_config:
            prompt_id = global_structure_config["prompt_id"]
        elif "prompt_id" in state_structure_config:
            prompt_id = state_structure_config["prompt_id"]
        else:
            return None
        return self.config["prompts"][prompt_id]["prompt"]

    def get_prompt_by_id(self, prompt_id: str) -> str | None:
        prompt_config = self.config["prompts"]
        if prompt_id in prompt_config:
            return prompt_config[prompt_id]["prompt"]
        return None

    # ALL METHODS RELATING TO THE WORKFLOW AND PIPELINE

    def end_workflow(self, task: CodeExecutionTask) -> ListArtifact:
        parent_outputs = task.parent_outputs
        questions = []
        for outputs in parent_outputs.values():
            if outputs.type == "InfoArtifact":
                continue
            questions.append(outputs)
        return ListArtifact(questions)

    def get_questions_workflow(self) -> Workflow:
        workflow = Workflow(id="create_question_workflow")
        # One parallel branch per question that still needs to be created.
        for _ in range(self.question_number - len(self.question_list)):
            task = StructureRunTask(
                structure_run_driver=LocalStructureRunDriver(
                    create_structure=self.get_single_question
                ),
                child_ids=["end_task"],
            )
            workflow.add_task(task)
        end_task = CodeExecutionTask(id="end_task", on_run=self.end_workflow)
        workflow.add_task(end_task)
        return workflow
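
    # Shape of the workflow built above: each branch is an independent
    # single-question sub-workflow, and all of them fan in to "end_task",
    # which gathers the surviving outputs into one ListArtifact.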

    def single_question_last_task(self, task: CodeExecutionTask) -> TextArtifact:
        parent_outputs = task.parent_outputs
        wrong_answers = parent_outputs["wrong_answers"].value  # Newline-separated list
        wrong_answers = wrong_answers.split("\n")
        question_and_answer = parent_outputs["get_question"].value  # Output is JSON
        try:
            question_and_answer = json.loads(question_and_answer)
        except json.JSONDecodeError:
            # Fallback for outputs with a non-JSON first line (e.g. a stray
            # preamble): drop it, rejoin, and parse again.
            question_and_answer = "".join(question_and_answer.split("\n")[1:])
            question_and_answer = json.loads(question_and_answer)
        inputs = task.input.value.split(",")
        question = {
            "Question": question_and_answer["Question"],
            "Answer": question_and_answer["Answer"],
            "Wrong Answers": wrong_answers,
            "Page": inputs[0],
            "Taxonomy": inputs[1],
        }
        return TextArtifact(question)

    def get_question_for_wrong_answers(self, task: CodeExecutionTask) -> TextArtifact:
        parent_outputs = task.parent_outputs
        question = parent_outputs["get_question"].value
        question = json.loads(question)["Question"]
        return TextArtifact(question)

    def get_separated_answer_for_wrong_answers(
        self, task: CodeExecutionTask
    ) -> TextArtifact:
        parent_outputs = task.parent_outputs
        answer = parent_outputs["get_question"].value
        print(answer)
        answer = json.loads(answer)["Answer"]
        return TextArtifact(answer)

    def make_rag_structure(
        self, vector_store: GriptapeCloudVectorStoreDriver
    ) -> Structure:
        if vector_store:
            tool = self.build_rag_tool(self.build_rag_engine(vector_store))
            use_rag_task = ToolTask(tool=tool)
            return Agent(tasks=[use_rag_task])
        errormsg = "No Vector Store"
        raise ValueError(errormsg)

    def get_single_question(self) -> Workflow:
        question_generator = Workflow(id="single_question")
        taxonomy = random.choice(self.taxonomy)
        taxonomy_prompts = {
            "Knowledge": "Generate a quiz question based ONLY on this information: {{parent_outputs['information_task']}}, then write the answer to the question. The interrogative verb for the question should be randomly chosen from: 'define', 'list', 'state', 'identify', 'label'.",
            "Comprehension": "Generate a quiz question based ONLY on this information: {{parent_outputs['information_task']}}, then write the answer to the question. The interrogative verb for the question should be randomly chosen from: 'explain', 'predict', 'interpret', 'infer', 'summarize', 'convert', 'give an example of x'.",
            "Application": "Generate a quiz question based ONLY on this information: {{parent_outputs['information_task']}}, then write the answer to the question. The structure of the question should be randomly chosen from: 'How could x be used to y?', 'How would you show/make use of/modify/demonstrate/solve/apply x to conditions y?'",
        }
        pages, driver = self.get_vector_store_id_from_page()
        get_information = StructureRunTask(
            id="information_task",
            input="What is the information in KB?",
            structure_run_driver=LocalStructureRunDriver(
                create_structure=lambda: self.make_rag_structure(driver)
            ),
            child_ids=["get_question"],
        )
        # Get KBs and select one, assign it to the structure or create the structure right here.
        # Rules for the subject matter expert: return only a JSON with question and answer as keys.
        generate_q_task = StructureRunTask(
            id="get_question",
            input=taxonomy_prompts[taxonomy],
            structure_run_driver=LocalStructureRunDriver(
                create_structure=lambda: self.get_structure("subject_matter_expert")
            ),
            parent_ids=["information_task"],
        )
        get_question_code_task = CodeExecutionTask(
            id="get_only_question",
            on_run=self.get_question_for_wrong_answers,
            parent_ids=["get_question"],
            child_ids=["wrong_answers"],
        )
        get_separated_answer_code_task = CodeExecutionTask(
            id="get_separated_answer",
            on_run=self.get_separated_answer_for_wrong_answers,
            parent_ids=["get_question"],
            child_ids=["wrong_answers"],
        )
        generate_wrong_answers = StructureRunTask(
            id="wrong_answers",
            input="""Write and return three incorrect answers for this question: {{parent_outputs['get_only_question']}}. The correct answer to the question is: {{parent_outputs['get_separated_answer']}}, and incorrect answers should have similar sentence structure to the correct answer. Write the incorrect answers from this information: {{parent_outputs['information_task']}}""",
            structure_run_driver=LocalStructureRunDriver(
                create_structure=lambda: self.get_structure("wrong_answers_generator")
            ),
            parent_ids=["get_only_question", "information_task"],
        )
        compile_task = CodeExecutionTask(
            id="compile_task",
            input=f"{pages}, {taxonomy}",
            on_run=self.single_question_last_task,
            parent_ids=["wrong_answers", "get_question"],
        )
        question_generator.add_tasks(
            get_information,
            generate_q_task,
            get_question_code_task,
            get_separated_answer_code_task,
            generate_wrong_answers,
            compile_task,
        )
        return question_generator
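
    # Dependency summary for the workflow built by get_single_question()
    # (parent -> children, matching the parent_ids/child_ids wired above):
    #   information_task     -> get_question, wrong_answers
    #   get_question         -> get_only_question, get_separated_answer, compile_task
    #   get_only_question    -> wrong_answers
    #   get_separated_answer -> wrong_answers
    #   wrong_answers        -> compile_task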

    def get_vector_store_id_from_page(
        self,
    ) -> tuple[str, GriptapeCloudVectorStoreDriver]:
        base_url = "https://cloud.griptape.ai/api"
        kb_url = f"{base_url}/knowledge-bases"
        headers = {"Authorization": f"Bearer {os.getenv('GT_CLOUD_API_KEY')}"}
        # TODO: This needs to change when I have my own bucket. Right now, I'm
        # using the 10 most recently made KBs.
        response = requests.get(url=kb_url, headers=headers)
        response.raise_for_status()
        data = response.json()
        possible_kbs = {}
        for kb in data["knowledge_bases"]:
            name = kb["name"]
            if "KB_section" not in name:
                continue
            # KB names are assumed to encode a page span, e.g. "KB_section_p10-p20".
            page_nums = name.split("p")[1:]
            start_page = int(page_nums[0].split("-")[0])
            end_page = int(page_nums[1])
            if end_page <= self.page_range[1] and start_page >= self.page_range[0]:
                possible_kbs[kb["knowledge_base_id"]] = f"{start_page}-{end_page}"
        kb_id = random.choice(list(possible_kbs.keys()))
        page_value = possible_kbs[kb_id]  # TODO: This won't help at all actually
        return page_value, GriptapeCloudVectorStoreDriver(
            api_key=os.getenv("GT_CLOUD_API_KEY", ""),
            knowledge_base_id=kb_id,
        )

    def get_taxonomy_vs(self) -> GriptapeCloudVectorStoreDriver:
        return GriptapeCloudVectorStoreDriver(
            api_key=os.getenv("GT_CLOUD_API_KEY", ""),
            knowledge_base_id="2c3a6f19-51a8-43c3-8445-c7fbe06bf460",
        )

    def build_rag_engine(
        self, vector_store_driver: GriptapeCloudVectorStoreDriver
    ) -> RagEngine:
        return RagEngine(
            retrieval_stage=RetrievalRagStage(
                retrieval_modules=[
                    VectorStoreRetrievalRagModule(
                        vector_store_driver=vector_store_driver,
                    )
                ],
            ),
            response_stage=ResponseRagStage(
                response_modules=[TextChunksResponseRagModule()]
            ),
        )

    def build_rag_tool(self, engine: RagEngine) -> RagTool:
        return RagTool(
            description="Contains information about the textbook. Use it ONLY for context.",
            rag_engine=engine,
        )
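
    # A minimal sketch of how the RAG pieces above compose, assuming `machine`
    # is an instance of a concrete UWBaseMachine subclass and the knowledge
    # base id is a hypothetical placeholder:
    #
    #   driver = GriptapeCloudVectorStoreDriver(
    #       api_key=os.getenv("GT_CLOUD_API_KEY", ""),
    #       knowledge_base_id="<your-kb-id>",
    #   )
    #   engine = machine.build_rag_engine(driver)
    #   tool = machine.build_rag_tool(engine)
    #   agent = Agent(tasks=[ToolTask(tool=tool)])  # same wiring as make_rag_structure(driver)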