from __future__ import annotations

import os
import random
from typing import TYPE_CHECKING, cast

import schema
from griptape.configs import Defaults
from griptape.configs.drivers import (
    OpenAiDriversConfig,
)
from griptape.drivers import GriptapeCloudVectorStoreDriver, OpenAiChatPromptDriver
from griptape.engines.rag import RagEngine
from griptape.engines.rag.modules import (
    TextChunksResponseRagModule,
    VectorStoreRetrievalRagModule,
)
from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage
from griptape.events import (
    BaseEvent,
    EventBus,
    EventListener,
    FinishStructureRunEvent,
)
from griptape.rules import Rule, Ruleset
from griptape.structures import Agent
from griptape.tasks import ToolTask
from griptape.tools import RagTool
from statemachine import State, StateMachine
from statemachine.factory import StateMachineMetaclass

if TYPE_CHECKING:
    from griptape.structures import Structure

Defaults.drivers_config = OpenAiDriversConfig(
    prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", max_tokens=4096)
)
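# Note (assumption, not stated in the original file): the OpenAI prompt driver above
# typically picks up credentials from the OPENAI_API_KEY environment variable, and the
# GriptapeCloudVectorStoreDriver used further down reads GT_CLOUD_API_KEY; both are
# expected to be set before running.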
# States will be:
# random_selection (does the dice roll and KB selection, plus sets up the information task)
# question_generation (generates the question and its correct answer)
# wrong_answer_generation (generates the wrong answers)
# compile_task (finishes everything and compiles it into a single result)
# TODO: How to get it to return everything
STATES = [
    "start",
    "random_selection",
    "get_textbook",
    "question_generation",
    "wrong_answer_generation",
    "audit_question",
    "compile_task",
    "end",
]

START = "start"
END = "end"

TRANSITIONS = [
    {
        "event": "next_state",
        "transitions": [
            {"from": "random_selection", "to": "get_textbook"},
            {"from": "get_textbook", "to": "question_generation"},
            {"from": "question_generation", "to": "wrong_answer_generation"},
            {"from": "wrong_answer_generation", "to": "audit_question"},
            {"from": "audit_question", "to": "compile_task"},
            {"from": "compile_task", "to": "end"},
        ],
    },
    {
        "event": "redo",
        "transitions": [
            {"from": "audit_question", "to": "wrong_answer_generation"},
        ],
    },
    {
        "event": "start_up",
        "transitions": [
            {"from": "start", "to": "random_selection"},
        ],
    },
    {"event": "end_state", "transitions": [{"from": "random_selection", "to": "end"}]},
]
RULESETS = {
    "specific_question_creator": [
        """Question should be a multiple choice quiz style question that assesses a
        student's knowledge of the information in the knowledge base (which should be
        referred to as 'the textbook'). Answer should be a correct answer to the
        question that uses information from the knowledge base. Do not return
        incorrect answers.""",
        """The length of the question should be 30 words at most.""",
        """Question should never reference or ask about an entire section, never
        reference or ask about a quote in the knowledge base, never ask for the page
        number of some information, and never ask for information about the file,
        document, or knowledge base.""",
        """The answer to the question should be short, but should not omit important
        information.""",
    ],
    "incorrect_answers_creator": [
        """All incorrect answers should be different, but plausible answers to the question.""",
        """Incorrect answers may reference material from the info provided as context, but must not be correct answers to the question.""",
        """Incorrect answers should always have a similar structure to the correct answer.""",
        """The length of all incorrect answers should be as close to the correct answer as possible while remaining plausible.""",
    ],
    "question_auditor_ruleset": [
        # """If any of the rules are false, return false and why. If they are all true, return true.""",
        """If any of the rules below are false, return True for the part of the question that makes them false.""",
        """The reason why it is false is between 3-7 words.""",
        """There is exactly one correct answer.""",
        """The correct answer has a clearly distinct meaning from all incorrect answers.""",
        """Incorrect answers are plausible to someone who does not know the correct answer.""",
        """All answer choices are on the same topic as the question.""",
        """All answer choices are relevant to the context of the question, with no unrelated concepts or entities.""",
        """All answers have semantically different meanings from one another, even if they are syntactically similar.""",
        """All answer choices are parallel to one another with respect to grammatical structure, length, and complexity.""",
    ],
}
STRUCTURES = {
    "subject_matter_expert": {"ruleset_ids": ["specific_question_creator"]},
    "wrong_answers_generator": {"ruleset_ids": ["incorrect_answers_creator"]},
    "question_auditor": {"ruleset_ids": ["question_auditor_ruleset"]},
}
class SingleQuestion(StateMachine):
    """Base class for the single-question state machine."""

    def __init__(self, **kwargs):
        self._structures = {}
        self.kb_ids = kwargs["kb_ids"]
        self.page_range: tuple = kwargs["page_range"]
        self.taxonomy_choices: list = kwargs["taxonomy_choices"]
        self.question: str = ""
        self.answer: str = ""
        self.wrong_answers: list = []
        self.generated_question: dict = {}
        self.taxonomy: str = ""
        self.rejected: bool = False
        self.give_up: int = 0
        self.reject_reason: str = ""
        self.reject_classification: str = ""

        def on_event(event: BaseEvent) -> None:
            """Forwards finished Griptape structure runs into the state machine."""
            try:
                self.send("griptape_event", event_=event.to_dict())
            except Exception as e:
                errormsg = "Could not forward the Griptape event to the state machine"
                raise ValueError(errormsg) from e

        EventBus.clear_event_listeners()
        EventBus.add_event_listener(
            EventListener(on_event, event_types=[FinishStructureRunEvent]),
        )
        super().__init__()
    @classmethod
    def create_statemachine(
        cls, taxonomy_choices: list, kb_ids: dict, page_range: tuple
    ) -> SingleQuestion:
        """Dynamically builds and instantiates the machine from STATES and TRANSITIONS."""
        states_instances = {}
        events = {}
        for state in STATES:
            initial = state == START
            final = state == END
            # Create the states
            states_instances[state] = State(value=state, initial=initial, final=final)
            if not (initial or final or state in ("random_selection", "compile_task")):
                # Create the internal transition that receives forwarded Griptape events
                transition = states_instances[state].to(
                    states_instances[state],
                    event="griptape_event",
                    on=f"on_event_{state}",
                    internal=True,
                )
                if "griptape_event" in events:
                    events["griptape_event"] |= transition
                else:
                    events["griptape_event"] = transition
        for transition in TRANSITIONS:
            for transition_data in transition["transitions"]:
                transition_value = states_instances[transition_data["from"]].to(
                    states_instances[transition_data["to"]],
                    event=transition["event"],
                    internal=False,
                )
                if transition["event"] in events:
                    events[transition["event"]] |= transition_value
                else:
                    events[transition["event"]] = transition_value
        attrs_mapper = {
            **states_instances,
            **events,
        }
        kwargs = {
            "taxonomy_choices": taxonomy_choices,
            "kb_ids": kb_ids,
            "page_range": page_range,
        }
        return cast(
            SingleQuestion,
            StateMachineMetaclass(cls.__name__, (cls,), attrs_mapper)(**kwargs),
        )
    # BENEATH ARE THE NECESSARY METHODS
    def on_enter_random_selection(self) -> None:
        # Get the random taxonomy
        self.taxonomy = random.choice(self.taxonomy_choices)
        # A dict lookup is used here so there is no need for an "eval"
        taxonomy_prompt = {
            "Knowledge": "Generate a quiz question based ONLY on the information. Then write the answer to the question. The interrogative verb for the question should be randomly chosen from: 'define', 'list', 'state', 'identify', 'label'. INFORMATION: ",
            "Comprehension": "Generate a quiz question based ONLY on the information. Then write the answer to the question. The interrogative verb for the question should be randomly chosen from: 'explain', 'predict', 'interpret', 'infer', 'summarize', 'convert', 'give an example of x'. INFORMATION: ",
            "Application": "Generate a quiz question based ONLY on the information. Then write the answer to the question. The structure of the question should be randomly chosen from: 'How could x be used to y?', 'How would you show/make use of/modify/demonstrate/solve/apply x to conditions y?' INFORMATION: ",
        }
        self.taxonomy_prompt = taxonomy_prompt[self.taxonomy]
        # Get the random page range and GriptapeCloudVectorStoreDriver
        pages, driver = self.get_vector_store_id_from_page()
        if driver is None:
            self.rejected = True
            self.reject_reason = "BAD KB PAGE RANGE"
            print(self.reject_reason)
            self.send("end_state")
            return
        self.pages = pages
        self.driver = driver
        self.send("next_state")
    def on_enter_get_textbook(self) -> None:
        # The retrieval agent is created lazily in this method
        if "get_information" not in self._structures:
            tool = self.build_rag_tool(self.build_rag_engine(self.driver))
            use_rag_task = ToolTask(tool=tool)
            information_retriever = Agent(id="get_information", tasks=[use_rag_task])
            self._structures["get_information"] = information_retriever
        self._structures["get_information"].run("What is the information in KB?")

    def on_event_get_textbook(self, event_: dict) -> None:
        event_type = event_["type"]
        match event_type:
            case "FinishStructureRunEvent":
                structure_id = event_["structure_id"]
                match structure_id:
                    case "get_information":
                        self.information = event_["output_task_output"]["value"]
                        self.send("next_state")
    def on_enter_question_generation(self) -> None:
        if "subject_matter_expert" not in self._structures:
            rulesets = self.get_rulesets("subject_matter_expert")
            subject_matter_expert = Agent(id="subject_matter_expert", rulesets=rulesets)
            subject_matter_expert.task.output_schema = schema.Schema(
                {"Question": str, "Answer": str}
            )
            self._structures["subject_matter_expert"] = subject_matter_expert
        self._structures["subject_matter_expert"].run(
            f"{self.taxonomy_prompt}{self.information}"
        )  # TODO: Will this work the same as before

    def on_event_question_generation(self, event_: dict) -> None:
        event_type = event_["type"]
        match event_type:
            case "FinishStructureRunEvent":
                structure_id = event_["structure_id"]
                match structure_id:
                    case "subject_matter_expert":
                        question = event_["output_task_output"]["value"]
                        # Save the question and answer separately
                        self.question = question["Question"]
                        self.answer = question["Answer"]
                        self.send("next_state")
    def on_enter_wrong_answer_generation(self) -> None:
        if "wrong_answers_generator" not in self._structures:
            rulesets = self.get_rulesets("wrong_answers_generator")
            wrong_answers_generator = Agent(
                id="wrong_answers_generator", rulesets=rulesets
            )
            wrong_answers_generator.task.output_schema = schema.Schema(
                {"1": str, "2": str, "3": str, "4": str}
            )
            self._structures["wrong_answers_generator"] = wrong_answers_generator
        if not self.rejected:
            prompt = f"""Write and return four incorrect answers for this question: {self.question}. The correct answer to the question is: {self.answer}, and incorrect answers should have similar sentence structure to the correct answer. Write the incorrect answers from this information: {self.information}"""
        else:
            prompt = f"""Write and return four incorrect answers for this question: {self.question}. The correct answer to the question is: {self.answer}, and incorrect answers should have similar sentence structure to the correct answer. Write the incorrect answers from this information: {self.information}. Answers should not be: {self.reject_reason}."""
            print(self.reject_reason)
        self._structures["wrong_answers_generator"].run(prompt)

    def on_event_wrong_answer_generation(self, event_: dict) -> None:
        event_type = event_["type"]
        match event_type:
            case "FinishStructureRunEvent":
                structure_id = event_["structure_id"]
                match structure_id:
                    case "wrong_answers_generator":
                        wrong_answers = event_["output_task_output"]["value"]
                        wrong_answers = [wrong_answers[x] for x in ["1", "2", "3", "4"]]
                        # Save the wrong answers
                        self.wrong_answers = wrong_answers
                        self.send("next_state")
    def on_enter_audit_question(self) -> None:
        if "question_auditor" not in self._structures:
            rulesets = self.get_rulesets("question_auditor")
            question_auditor = Agent(id="question_auditor", rulesets=rulesets)
            # question_auditor.task.output_schema = schema.Schema(
            #     {
            #         "keep": bool,
            #         "why": schema.Optional(
            #             {
            #                 "Question": bool,
            #                 "Answer": bool,
            #                 "Wrong Answers": bool,
            #                 "Reason": str,
            #             }
            #         ),
            #     }
            # )
            question_auditor.task.output_schema = schema.Schema(
                {
                    "Bad_Question": bool,
                    "Bad_Answer": bool,
                    "Bad_Wrong_Answers": bool,
                    "Reason": schema.Optional(str),
                }
            )
            self._structures["question_auditor"] = question_auditor
        # prompt = f"This is the question: {self.question}. This is the answer: {self.answer}. These are the incorrect answers: {self.wrong_answers}. This is the information given: {self.information}. IF the question is not kept, return True for the reason why from 'Question', 'Answers', 'Wrong Answers'."
        prompt = f"This is the question: {self.question}. This is the answer: {self.answer}. These are the incorrect answers: {self.wrong_answers}. This is the information given: {self.information}. IF the question should not be kept, return True for the reason why from 'Bad_Question', 'Bad_Answer', 'Bad_Wrong_Answers'."
        self._structures["question_auditor"].run(prompt)
    def on_event_audit_question(self, event_: dict) -> None:
        event_type = event_["type"]
        match event_type:
            case "FinishStructureRunEvent":
                structure_id = event_["structure_id"]
                match structure_id:
                    case "question_auditor":
                        if self.give_up >= 3:
                            self.rejected = True
                            self.reject_reason += "\nToo many tries"
                            self.send("next_state")
                            return
                        self.give_up += 1
                        audit = event_["output_task_output"]["value"]
                        # TODO: Go back to some other state that checks the quality bar
                        # if audit["keep"]:
                        #     self.send("next_state")
                        # else:
                        #     self.rejected = True
                        #     self.reject_reason = audit["why"]["Reason"]
                        #     if audit["why"]["Question"]:
                        #         self.send("next_state")
                        #         return
                        #     if audit["why"]["Answer"]:
                        #         self.send("next_state")
                        #         return
                        #     if audit["why"]["Wrong Answers"]:
                        #         self.send("redo")  # Goes back to generate more wrong answers
                        #         return
                        #     self.send("next_state")
                        print(audit)
                        if audit["Bad_Question"]:
                            self.rejected = True
                            self.reject_reason = audit["Reason"]
                            self.reject_classification = "Bad_Question"
                            self.send("next_state")
                            return
                        if audit["Bad_Answer"]:
                            self.rejected = True
                            self.reject_reason = audit["Reason"]
                            self.reject_classification = "Bad_Answer"
                            self.send("next_state")
                            return
                        if audit["Bad_Wrong_Answers"]:
                            self.rejected = True
                            self.reject_reason = audit["Reason"]
                            self.reject_classification = "Bad_Wrong_Answers"
                            # Goes back to generate more wrong answers
                            self.send("redo")
                            return
                        self.rejected = False
                        self.send("next_state")
    def on_enter_compile_task(self) -> None:
        # TODO: Logic to determine if I should go back to wrong answers
        question = {
            "Question": self.question,
            "Answer": self.answer,
            "Wrong Answers": self.wrong_answers,
            "Page": self.pages,
            "Taxonomy": self.taxonomy,
        }
        if self.rejected:
            question["Reject Classification"] = self.reject_classification
            question["Reason"] = self.reject_reason
        self.generated_question = question
        self.send("next_state")

    # TODO: Does this return output?
    def on_enter_end(self) -> dict:
        return self.generated_question
    # HELPER METHODS BELOW
    def get_rulesets(self, structure_id: str) -> list:
        final_ruleset_list = []
        ruleset_ids = STRUCTURES[structure_id]["ruleset_ids"]
        for ruleset_id in ruleset_ids:
            ruleset_rules = RULESETS[ruleset_id]
            rules = [Rule(rule) for rule in ruleset_rules]
            final_ruleset_list.append(Ruleset(ruleset_id, rules=rules))
        return final_ruleset_list

    def get_vector_store_id_from_page(
        self,
    ) -> tuple[str, GriptapeCloudVectorStoreDriver | None]:
        possible_kbs = {}
        for name, kb_id in self.kb_ids.items():
            page_nums = name.split("p")[1:]
            start_page = int(page_nums[0].split("-")[0])
            end_page = int(page_nums[1])
            if end_page <= self.page_range[1] and start_page >= self.page_range[0]:
                possible_kbs[kb_id] = f"{start_page}-{end_page}"
        if not possible_kbs:
            return ("No KBs in range", None)
        kb_id = random.choice(list(possible_kbs.keys()))
        page_value = possible_kbs[kb_id]
        return page_value, GriptapeCloudVectorStoreDriver(
            api_key=os.getenv("GT_CLOUD_API_KEY", ""),
            knowledge_base_id=kb_id,
        )
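    # Note (illustrative, based on the __main__ example below): get_vector_store_id_from_page
    # assumes KB names follow a "p<start>-p<end>" convention, e.g. "p126-p129" parses to
    # start_page=126 and end_page=129, which is eligible for a requested page_range of (120, 150).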
    # Uses this and the helpers below to build the RagTool that gets information from the KB
    def build_rag_engine(
        self, vector_store_driver: GriptapeCloudVectorStoreDriver
    ) -> RagEngine:
        return RagEngine(
            retrieval_stage=RetrievalRagStage(
                retrieval_modules=[
                    VectorStoreRetrievalRagModule(
                        vector_store_driver=vector_store_driver,
                    )
                ],
            ),
            response_stage=ResponseRagStage(
                response_modules=[TextChunksResponseRagModule()]
            ),
        )

    def build_rag_tool(self, engine: RagEngine) -> RagTool:
        return RagTool(
            description="Contains information about the textbook. Use it ONLY for context.",
            rag_engine=engine,
        )

    def make_rag_structure(
        self, vector_store: GriptapeCloudVectorStoreDriver
    ) -> Structure:
        if vector_store:
            tool = self.build_rag_tool(self.build_rag_engine(vector_store))
            use_rag_task = ToolTask(tool=tool)
            return Agent(tasks=[use_rag_task])
        errormsg = "No Vector Store"
        raise ValueError(errormsg)
if __name__ == "__main__":
    flow = SingleQuestion.create_statemachine(
        ["Comprehension"],
        {"p126-p129": "9efbb8ab-6a5e-4bca-aab0-7f7500bfb7b5"},
        (120, 150),
    )
    flow.send("start_up")
    # When incorporating into the main flow, we can just read flow.generated_question
    # and use that value onwards (see the sketch below).
    # TODO: Do any events need to be sent?
    print(flow.generated_question)
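# A minimal sketch (not part of the original flow) of how a caller might collect several
# questions by building one machine per question; the taxonomy list, KB mapping, and page
# range below are illustrative values reused from the example above.
#
# def generate_questions(n: int) -> list[dict]:
#     questions = []
#     for _ in range(n):
#         machine = SingleQuestion.create_statemachine(
#             ["Knowledge", "Comprehension", "Application"],
#             {"p126-p129": "9efbb8ab-6a5e-4bca-aab0-7f7500bfb7b5"},
#             (120, 150),
#         )
#         machine.send("start_up")
#         questions.append(machine.generated_question)
#     return questions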