from __future__ import annotations

import ast
import csv
import random
from pathlib import Path
from typing import TYPE_CHECKING

import schema

from uw_programmatic.base_machine import UWBaseMachine

if TYPE_CHECKING:
    from griptape.tools import BaseTool

class UWMachine(UWBaseMachine):
    """State machine with GOAP."""

    @property
    def tools(self) -> dict[str, BaseTool]:
        return {}

    def start_machine(self) -> None:
        """Starts the machine."""
        # Clear input history.
        # Clear csv file.
        self.retrieve_vector_stores()
        self.send("enter_first_state")

    def on_enter_gather_parameters(self) -> None:
        # Reinitializes the state machine, including the give-up counter.
        self.current_question_count = 0
        self.give_up_count = 0
        self.question_list = []
        self.rejected_questions = []

    # The first state: listens for Gradio, which supplies the parameters to search for.
    def on_event_gather_parameters(self, event_: dict) -> None:
        event_source = event_["type"]
        event_value = event_["value"]
        match event_source:
            case "user_input":
                parameters = event_value
                self.page_range = parameters["page_range"]
                self.question_number = parameters["question_number"]
                self.taxonomy = parameters["taxonomy"]
                self.errored = False
                self.send("next_state")
            case "griptape_event":
                if event_value["structure_id"] == "create_question_workflow":
                    pass
            case _:
                err_msg = f"Unexpected Transition Event ID: {event_value}."
                raise ValueError(err_msg)

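    # Illustrative shape of the "user_input" event handled above. The real
    # payload comes from the Gradio layer; the field values here are invented
    # for illustration, only the keys are taken from the handler:
    #   {"type": "user_input",
    #    "value": {"page_range": (1, 40), "question_number": 10,
    #              "taxonomy": ["Analyze"]}}
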
    # Gives up if no new questions have been generated three tries in a row.
    # If the question list has reached the required size, sends to the end.
    def on_enter_evaluate_q_count(self) -> None:
        if len(self.question_list) <= self.current_question_count:
            self.give_up_count += 1
        else:
            self.current_question_count = len(self.question_list)
            self.give_up_count = 0
        if self.give_up_count >= 3:
            self.send("finish_state")  # Go to output questions.
            return
        if len(self.question_list) >= self.question_number:
            self.send("finish_state")  # Go to output questions.
        else:
            self.send("next_state")  # Go to need more questions.

    # Necessary for the state machine to not throw errors.
    def on_event_evaluate_q_count(self, event_: dict) -> None:
        pass

    def on_enter_need_more_q(self) -> None:
        # Create the entire workflow to create another question.
        self.get_questions_workflow().run()

    # Receives the output of the workflow: a ListArtifact of TextArtifacts,
    # each holding a question (Question, Answer, Wrong Answers, Taxonomy, Page).
    def on_event_need_more_q(self, event_: dict) -> None:
        event_source = event_["type"]
        event_value = event_["value"]
        match event_source:
            case "griptape_event":
                event_type = event_value["type"]
                match event_type:
                    case "FinishStructureRunEvent":
                        structure_id = event_value["structure_id"]
                        match structure_id:
                            case "create_question_workflow":
                                # TODO: Can you use task.output_schema on a workflow?
                                values = event_value["output_task_output"]["value"]
                                questions = [
                                    ast.literal_eval(question["value"])
                                    for question in values
                                ]
                                # Parsed from the workflow's ListArtifact output.
                                self.most_recent_questions = questions
                                self.send("next_state")
                    case _:
                        print(f"Error: {event_}")
            case _:
                print(f"Unexpected: {event_}")

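    # Illustrative (not verbatim) shape of the Griptape finish event consumed
    # above; only the keys this handler actually reads are shown, and each
    # inner "value" string is the repr of a question dict for literal_eval:
    #   {"type": "griptape_event",
    #    "value": {"type": "FinishStructureRunEvent",
    #              "structure_id": "create_question_workflow",
    #              "output_task_output": {"value": [{"value": "{...}"}]}}}
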
    # Merges the existing and new questions and sends them to the similarity
    # auditor, which removes near-duplicate questions.
    def on_enter_assess_generated_q(self) -> None:
        merged_list = [*self.question_list, *self.most_recent_questions]
        prompt = f"{merged_list}"
        similarity_auditor = self.get_structure("similarity_auditor")
        similarity_auditor.task.output_schema = schema.Schema(
            {
                "list": schema.Schema(
                    [
                        {
                            "Question": str,
                            "Answer": str,
                            "Wrong Answers": schema.Schema([str]),
                            "Page": str,
                            "Taxonomy": str,
                        }
                    ]
                )
            }
        )
        similarity_auditor.run(prompt)

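    # Example of one record the auditor's output_schema above would accept
    # (all contents invented for illustration):
    #   {"Question": "What is ...?", "Answer": "42",
    #    "Wrong Answers": ["41", "43", "44"], "Page": "17", "Taxonomy": "Recall"}
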
    # Sets self.question_list to the returned question list, with similar
    # questions removed.
    def on_event_assess_generated_q(self, event_: dict) -> None:
        event_source = event_["type"]
        event_value = event_["value"]
        match event_source:
            case "griptape_event":
                event_type = event_value["type"]
                match event_type:
                    case "FinishStructureRunEvent":
                        structure_id = event_value["structure_id"]
                        match structure_id:
                            case "similarity_auditor":
                                new_question_list = event_value[
                                    "output_task_output"
                                ]["value"]["list"]
                                self.question_list = new_question_list
                                self.send("next_state")  # Go to evaluate_q_count.

    # Writes and saves a csv in the expected format to outputs/professor_guide.csv.
    def on_enter_output_q(self) -> None:
        file_path = Path.cwd().joinpath("outputs/professor_guide.csv")
        file_path.parent.mkdir(parents=True, exist_ok=True)
        with file_path.open("w+", newline="") as file:
            writer = csv.writer(file)
            for question in self.question_list:
                # Row layout: type, blank, point value, question text, 1-based
                # column of the correct answer, the answer choices, page, taxonomy.
                new_row = ["MC", "", 1]
                new_row.append(question["Question"])
                wrong_answers = list(question["Wrong Answers"])
                # Pick a random 1-based slot for the correct answer among
                # len(wrong_answers) + 1 total choices.
                column = random.randint(1, len(wrong_answers) + 1)
                new_row.append(column)
                for i in range(1, len(wrong_answers) + 2):
                    if i == column:
                        new_row.append(question["Answer"])
                    else:
                        new_row.append(wrong_answers.pop())
                new_row.append(question["Page"])
                new_row.append(question["Taxonomy"])
                writer.writerow(new_row)
            if self.give_up_count >= 3:
                writer.writerow(["Failed to generate more questions."])
        rejected_path = Path.cwd().joinpath("outputs/rejected_list.csv")
        with rejected_path.open("w+", newline="") as rejected_file:
            writer = csv.writer(rejected_file)
            for question in self.rejected_questions:
                writer.writerow(question.values())
        self.send("next_state")  # Back to gather_parameters.

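    # An illustrative professor_guide.csv row produced above, for a question
    # with three wrong answers and the correct answer landing in slot 2
    # (all cell values invented for this example):
    #   MC,,1,What is ...?,2,41,42,43,44,17,Recall
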
    # Necessary to prevent errors being thrown from the state machine.
    def on_event_output_q(self, event_: dict) -> None:
        pass
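

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original flow. Assumptions
    # this file does not confirm: UWBaseMachine can be constructed with no
    # arguments, and the Griptape structures and vector stores it relies on
    # are configured in the surrounding environment.
    machine = UWMachine()
    machine.start_machine()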