from __future__ import annotations

import os
import random

import schema
from typing import TYPE_CHECKING, cast

from griptape.configs import Defaults
from griptape.configs.drivers import (
    OpenAiDriversConfig,
)
from griptape.drivers import GriptapeCloudVectorStoreDriver, OpenAiChatPromptDriver
from griptape.engines.rag import RagEngine
from griptape.engines.rag.modules import (
    TextChunksResponseRagModule,
    VectorStoreRetrievalRagModule,
)
from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage
from griptape.events import (
    BaseEvent,
    EventBus,
    EventListener,
    FinishStructureRunEvent,
)
from griptape.rules import Rule, Ruleset
from griptape.structures import Agent
from griptape.tasks import ToolTask
from griptape.tools import RagTool
from statemachine import State, StateMachine
from statemachine.factory import StateMachineMetaclass

if TYPE_CHECKING:
    from griptape.structures import Structure

Defaults.drivers_config = OpenAiDriversConfig(
    prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", max_tokens=4096)
)
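
# Note: setting Defaults.drivers_config makes every Agent created below use this
# OpenAI chat prompt driver unless a different driver is passed in explicitly.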

# States:
# random_selection (rolls for a taxonomy level and a knowledge base, then hands off to the information task)
# get_textbook (pulls the source information out of the chosen knowledge base)
# question_generation (generates the question and its correct answer)
# wrong_answer_generation (generates the incorrect answer choices)
# audit_question (checks the generated question against the quality rules)
# compile_task (compiles everything into a single question dict)
# TODO: How to get it to return everything
STATES = [
    "start",
    "random_selection",
    "get_textbook",
    "question_generation",
    "wrong_answer_generation",
    "audit_question",
    "compile_task",
    "end",
]
START = "start"
END = "end"
TRANSITIONS = [
    {
        "event": "next_state",
        "transitions": [
            {"from": "random_selection", "to": "get_textbook"},
            {"from": "get_textbook", "to": "question_generation"},
            {"from": "question_generation", "to": "wrong_answer_generation"},
            {"from": "wrong_answer_generation", "to": "audit_question"},
            {"from": "audit_question", "to": "compile_task"},
            {"from": "compile_task", "to": "end"},
        ],
    },
    {
        "event": "redo",
        "transitions": [
            {"from": "audit_question", "to": "wrong_answer_generation"},
        ],
    },
    {
        "event": "start_up",
        "transitions": [
            {"from": "start", "to": "random_selection"},
        ],
    },
    {"event": "end_state", "transitions": [{"from": "random_selection", "to": "end"}]},
]
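
# create_statemachine() below turns STATES and TRANSITIONS into a concrete
# python-statemachine class: "next_state" walks
# random_selection -> get_textbook -> question_generation -> wrong_answer_generation
# -> audit_question -> compile_task -> end, while "redo" loops the auditor back to
# wrong answer generation and "end_state" bails out early.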

RULESETS = {
    "specific_question_creator": [
        """Question should be a multiple choice quiz style question that assesses a
        student's knowledge of the information in the knowledge base (which should be
        referred to as 'the textbook'). Answer should be a correct answer to the
        question that uses information from the knowledge base. Do not return
        incorrect answers.""",
        """The length of the question should be 30 words at most.""",
        """Question should never reference or ask about an entire section, never reference
        or ask about a quote in the knowledge base, never ask for the page number of
        some information, and never ask for information about the file, document, or
        knowledge base.""",
        """The answer to the question should be short, but should not omit important
        information.""",
    ],
    "incorrect_answers_creator": [
        """All incorrect answers should be different, but plausible answers to the question.""",
        """Incorrect answers may reference material from the info provided as context, but must not be correct answers to the question.""",
        """Incorrect answers should always have a similar structure to the correct answer.""",
        """The length of all incorrect answers should be as close to the correct answer as possible while remaining plausible.""",
    ],
    "question_auditor_ruleset": [
        # """If any of the rules are false, return false and why. If they are all true, return true.""",
        """If any of the rules are false, return True for the part of the question that makes them false.""",
        """The reason why it is false should be between 3 and 7 words.""",
        """There is exactly one correct answer.""",
        """The correct answer has a clearly distinct meaning from all incorrect answers.""",
        """Incorrect answers are plausible to someone who does not know the correct answer.""",
        """All answer choices are on the same topic as the question.""",
        """All answer choices are relevant to the context of the question, with no unrelated concepts or entities.""",
        """All answers have semantically different meanings from one another, even if they are syntactically similar.""",
        """All answer choices are parallel to one another with respect to grammatical structure, length, and complexity.""",
    ],
}
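
# STRUCTURES maps each agent id used in this flow to the ruleset ids it is built
# with; get_rulesets() below resolves these ids against RULESETS.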
STRUCTURES = {
    "subject_matter_expert": {"ruleset_ids": ["specific_question_creator"]},
    "wrong_answers_generator": {"ruleset_ids": ["incorrect_answers_creator"]},
    "question_auditor": {"ruleset_ids": ["question_auditor_ruleset"]},
}


class SingleQuestion(StateMachine):
    """Base class for the single-question state machine."""

    def __init__(self, **kwargs):
        self._structures = {}
        self.kb_ids = kwargs["kb_ids"]
        self.page_range: tuple = kwargs["page_range"]
        self.taxonomy_choices: list = kwargs["taxonomy_choices"]
        self.question: str = ""
        self.answer: str = ""
        self.wrong_answers: list = []
        self.generated_question: dict = {}
        self.taxonomy: str = ""
        self.rejected: bool = False
        self.reject_classification: str = ""
        self.give_up: int = 0
        self.reject_reason: str = ""

        def on_event(event: BaseEvent) -> None:
            """Forwards finished-run events from Griptape into the state machine."""
            try:
                self.send("griptape_event", event_=event.to_dict())
            except Exception as e:
                errormsg = "Could not forward the Griptape event to the state machine"
                raise ValueError(errormsg) from e
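
        # Only FinishStructureRunEvents are forwarded; each on_event_<state> handler
        # below dispatches on the structure_id carried in the event payload.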
        EventBus.clear_event_listeners()
        EventBus.add_event_listener(
            EventListener(on_event, event_types=[FinishStructureRunEvent]),
        )
        super().__init__()

    @classmethod
    def create_statemachine(
        cls, taxonomy_choices: list, kb_ids: dict, page_range: tuple
    ) -> SingleQuestion:
        states_instances = {}
        events = {}
        for state in STATES:
            initial = state == START
            final = state == END
            # Creates the states
            states_instances[state] = State(value=state, initial=initial, final=final)
            if not (initial or final or state in ("random_selection", "compile_task")):
                # Creates the internal transition
                transition = states_instances[state].to(
                    states_instances[state],
                    event="griptape_event",
                    on=f"on_event_{state}",
                    internal=True,
                )
                if "griptape_event" in events:
                    events["griptape_event"] |= transition
                else:
                    events["griptape_event"] = transition
        for transition in TRANSITIONS:
            for transition_data in transition["transitions"]:
                transition_value = states_instances[transition_data["from"]].to(
                    states_instances[transition_data["to"]],
                    event=transition["event"],
                    internal=False,
                )
                if transition["event"] in events:
                    events[transition["event"]] |= transition_value
                else:
                    events[transition["event"]] = transition_value
        attrs_mapper = {
            **states_instances,
            **events,
        }
        kwargs = {
            "taxonomy_choices": taxonomy_choices,
            "kb_ids": kb_ids,
            "page_range": page_range,
        }
        return cast(
            SingleQuestion,
            StateMachineMetaclass(cls.__name__, (cls,), attrs_mapper)(**kwargs),
        )
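
    # python-statemachine resolves callbacks by name: entering state "x" calls
    # on_enter_x, and the internal "griptape_event" transition created above calls
    # on_event_x, which is how the handlers below get wired up.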
    # BENEATH ARE THE NECESSARY METHODS
    def on_enter_random_selection(self) -> None:
        # Get the random taxonomy
        self.taxonomy = random.choice(self.taxonomy_choices)
        # Map each taxonomy level directly to its prompt so no eval() is needed.
        taxonomy_prompt = {
            "Knowledge": "Generate a quiz question based ONLY on the information. Then write the answer to the question. The interrogative verb for the question should be randomly chosen from: 'define', 'list', 'state', 'identify', 'label'. INFORMATION: ",
            "Comprehension": "Generate a quiz question based ONLY on the information. Then write the answer to the question. The interrogative verb for the question should be randomly chosen from: 'explain', 'predict', 'interpret', 'infer', 'summarize', 'convert', 'give an example of x'. INFORMATION: ",
            "Application": "Generate a quiz question based ONLY on the information. Then write the answer to the question. The structure of the question should be randomly chosen from: 'How could x be used to y?', 'How would you show/make use of/modify/demonstrate/solve/apply x to conditions y?' INFORMATION: ",
        }
        self.taxonomy_prompt = taxonomy_prompt[self.taxonomy]
        # Get the random page range and a GriptapeCloudVectorStoreDriver for it
        pages, driver = self.get_vector_store_id_from_page()
        if driver is None:
            self.send("end_state")
            self.rejected = True
            self.reject_reason = "BAD KB PAGE RANGE"
            print(self.reject_reason)
            return
        self.pages = pages
        self.driver = driver
        self.send("next_state")

    def on_enter_get_textbook(self) -> None:
        # Create the information-retrieval agent lazily the first time this state is entered
        if "get_information" not in self._structures:
            tool = self.build_rag_tool(self.build_rag_engine(self.driver))
            use_rag_task = ToolTask(tool=tool)
            information_retriever = Agent(id="get_information", tasks=[use_rag_task])
            self._structures["get_information"] = information_retriever
        self._structures["get_information"].run("What is the information in KB?")

    def on_event_get_textbook(self, event_: dict) -> None:
        event_type = event_["type"]
        match event_type:
            case "FinishStructureRunEvent":
                structure_id = event_["structure_id"]
                match structure_id:
                    case "get_information":
                        self.information = event_["output_task_output"]["value"]
                        self.send("next_state")

    def on_enter_question_generation(self) -> None:
        if "subject_matter_expert" not in self._structures:
            rulesets = self.get_rulesets("subject_matter_expert")
            subject_matter_expert = Agent(id="subject_matter_expert", rulesets=rulesets)
            subject_matter_expert.task.output_schema = schema.Schema(
                {"Question": str, "Answer": str}
            )
            self._structures["subject_matter_expert"] = subject_matter_expert
        self._structures["subject_matter_expert"].run(
            f"{self.taxonomy_prompt}{self.information}"
        )  # TODO: Will this work the same as before

    def on_event_question_generation(self, event_: dict) -> None:
        event_type = event_["type"]
        match event_type:
            case "FinishStructureRunEvent":
                structure_id = event_["structure_id"]
                match structure_id:
                    case "subject_matter_expert":
                        question = event_["output_task_output"]["value"]
                        # save question and answer separately
                        self.question = question["Question"]
                        self.answer = question["Answer"]
                        self.send("next_state")

    def on_enter_wrong_answer_generation(self) -> None:
        if "wrong_answers_generator" not in self._structures:
            rulesets = self.get_rulesets("wrong_answers_generator")
            wrong_answers_generator = Agent(
                id="wrong_answers_generator", rulesets=rulesets
            )
            wrong_answers_generator.task.output_schema = schema.Schema(
                {"1": str, "2": str, "3": str, "4": str}
            )
            self._structures["wrong_answers_generator"] = wrong_answers_generator
        if not self.rejected:
            prompt = f"""Write and return four incorrect answers for this question: {self.question}. The correct answer to the question is: {self.answer}, and incorrect answers should have similar sentence structure to the correct answer. Write the incorrect answers from this information: {self.information}"""
        else:
            prompt = f"""Write and return four incorrect answers for this question: {self.question}. The correct answer to the question is: {self.answer}, and incorrect answers should have similar sentence structure to the correct answer. Write the incorrect answers from this information: {self.information}. Answers should not be: {self.reject_reason}."""
            print(self.reject_reason)
        self._structures["wrong_answers_generator"].run(prompt)

    def on_event_wrong_answer_generation(self, event_: dict) -> None:
        event_type = event_["type"]
        match event_type:
            case "FinishStructureRunEvent":
                structure_id = event_["structure_id"]
                match structure_id:
                    case "wrong_answers_generator":
                        wrong_answers = event_["output_task_output"]["value"]
                        wrong_answers = [wrong_answers[x] for x in ["1", "2", "3", "4"]]
                        # save the wrong answers
                        self.wrong_answers = wrong_answers
                        self.send("next_state")

    def on_enter_audit_question(self) -> None:
        if "question_auditor" not in self._structures:
            rulesets = self.get_rulesets("question_auditor")
            question_auditor = Agent(id="question_auditor", rulesets=rulesets)
            # question_auditor.task.output_schema = schema.Schema(
            #     {
            #         "keep": bool,
            #         "why": schema.Optional(
            #             {
            #                 "Question": bool,
            #                 "Answer": bool,
            #                 "Wrong Answers": bool,
            #                 "Reason": str,
            #             }
            #         ),
            #     }
            # )
            question_auditor.task.output_schema = schema.Schema(
                {
                    "Bad_Question": bool,
                    "Bad_Answer": bool,
                    "Bad_Wrong_Answers": bool,
                    "Reason": schema.Optional(str),
                }
            )
            self._structures["question_auditor"] = question_auditor
        # prompt = f"This is the question: {self.question}. This is the answer: {self.answer}. These are the incorrect answers:{self.wrong_answers}. This is the information given:{self.information}. IF the question is not kept, return True for the reason why from 'Question', 'Answers', 'Wrong Answers'."
        prompt = f"This is the question: {self.question}. This is the answer: {self.answer}. These are the incorrect answers: {self.wrong_answers}. This is the information given: {self.information}. If the question should not be kept, return True for the reason why from 'Bad_Question', 'Bad_Answer', 'Bad_Wrong_Answers'."
        self._structures["question_auditor"].run(prompt)

    def on_event_audit_question(self, event_: dict) -> None:
        event_type = event_["type"]
        match event_type:
            case "FinishStructureRunEvent":
                structure_id = event_["structure_id"]
                match structure_id:
                    case "question_auditor":
                        if self.give_up >= 3:
                            self.rejected = True
                            self.reject_reason += " \n Too many tries"
                            self.send("next_state")
                            return
                        self.give_up += 1
                        audit = event_["output_task_output"]["value"]
                        # TODO: Go back to some other state that checks the quality bar
                        # if audit["keep"]:
                        #     self.send("next_state")
                        # else:
                        #     self.rejected = True
                        #     self.reject_reason = audit["why"]["Reason"]
                        #     if audit["why"]["Question"]:
                        #         self.send("next_state")
                        #         return
                        #     if audit["why"]["Answer"]:
                        #         self.send("next_state")
                        #         return
                        #     if audit["why"]["Wrong Answers"]:
                        #         self.send(
                        #             "redo"
                        #         )  # Goes back to generate more wrong answers
                        #         return
                        #     self.send("next_state")
                        print(audit)
                        if audit["Bad_Question"]:
                            self.rejected = True
                            self.reject_reason = audit["Reason"]
                            self.reject_classification = "Bad_Question"
                            self.send("next_state")
                            return
                        if audit["Bad_Answer"]:
                            self.rejected = True
                            self.reject_reason = audit["Reason"]
                            self.reject_classification = "Bad_Answer"
                            self.send("next_state")
                            return
                        if audit["Bad_Wrong_Answers"]:
                            self.rejected = True
                            self.reject_reason = audit["Reason"]
                            self.reject_classification = "Bad_Wrong_Answers"
                            self.send("redo")
                            return
                        self.rejected = False
                        self.send("next_state")

    def on_enter_compile_task(self) -> None:
        # TODO: Logic to determine if I should go back to wrong answers
        question = {
            "Question": self.question,
            "Answer": self.answer,
            "Wrong Answers": self.wrong_answers,
            "Page": self.pages,
            "Taxonomy": self.taxonomy,
        }
        if self.rejected:
            question["Reject Classification"] = self.reject_classification
            question["Reason"] = self.reject_reason
        self.generated_question = question
        self.send("next_state")

    # TODO: Does this return output
    def on_enter_end(self) -> dict:
        return self.generated_question

    # HELPER METHODS BELOW
    def get_rulesets(self, structure_id: str) -> list:
        final_ruleset_list = []
        ruleset_ids = STRUCTURES[structure_id]["ruleset_ids"]
        for ruleset_id in ruleset_ids:
            ruleset_rules = RULESETS[ruleset_id]
            rules = [Rule(rule) for rule in ruleset_rules]
            final_ruleset_list.append(Ruleset(ruleset_id, rules=rules))
        return final_ruleset_list

    def get_vector_store_id_from_page(
        self,
    ) -> tuple[str, GriptapeCloudVectorStoreDriver | None]:
        possible_kbs = {}
        for name, kb_id in self.kb_ids.items():
            page_nums = name.split("p")[1:]
            start_page = int(page_nums[0].split("-")[0])
            end_page = int(page_nums[1])
            if end_page <= self.page_range[1] and start_page >= self.page_range[0]:
                possible_kbs[kb_id] = f"{start_page}-{end_page}"
        if not possible_kbs:
            return ("No KBs in range", None)
        kb_id = random.choice(list(possible_kbs.keys()))
        page_value = possible_kbs[kb_id]
        return page_value, GriptapeCloudVectorStoreDriver(
            api_key=os.getenv("GT_CLOUD_API_KEY", ""),
            knowledge_base_id=kb_id,
        )
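
    # Example of the assumed KB naming scheme: a key like "p126-p129" parses to
    # start_page=126 and end_page=129, which falls inside page_range=(120, 150) and
    # therefore qualifies for random selection.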

    # The helpers below build the RAG tool used to pull information out of the KB
    def build_rag_engine(
        self, vector_store_driver: GriptapeCloudVectorStoreDriver
    ) -> RagEngine:
        return RagEngine(
            retrieval_stage=RetrievalRagStage(
                retrieval_modules=[
                    VectorStoreRetrievalRagModule(
                        vector_store_driver=vector_store_driver,
                    )
                ],
            ),
            response_stage=ResponseRagStage(
                response_modules=[TextChunksResponseRagModule()]
            ),
        )
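
    # The retrieval stage pulls matching text chunks from the vector store, and
    # TextChunksResponseRagModule returns those chunks as-is rather than asking an
    # LLM to summarize them, so the raw textbook text reaches the question generator.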

    def build_rag_tool(self, engine: RagEngine) -> RagTool:
        return RagTool(
            description="Contains information about the textbook. Use it ONLY for context.",
            rag_engine=engine,
        )

    def make_rag_structure(
        self, vector_store: GriptapeCloudVectorStoreDriver
    ) -> Structure:
        if vector_store:
            tool = self.build_rag_tool(self.build_rag_engine(vector_store))
            use_rag_task = ToolTask(tool=tool)
            return Agent(tasks=[use_rag_task])
        errormsg = "No Vector Store"
        raise ValueError(errormsg)


if __name__ == "__main__":
    flow = SingleQuestion.create_statemachine(
        ["Comprehension"],
        {"p126-p129": "9efbb8ab-6a5e-4bca-aab0-7f7500bfb7b5"},
        (120, 150),
    )
    flow.send("start_up")
    # When incorporating this into the main flow, read flow.generated_question after
    # the run and use that value onwards.
    # TODO: Do any events need to be sent?
    print(flow.generated_question)
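
    # A successful run's flow.generated_question should look roughly like this
    # (sketch only; the field values are assumptions):
    # {
    #     "Question": "...",
    #     "Answer": "...",
    #     "Wrong Answers": ["...", "...", "...", "..."],
    #     "Page": "126-129",
    #     "Taxonomy": "Comprehension",
    # }
    # with "Reject Classification" and "Reason" added when the auditor rejects it.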