Spaces:
Sleeping
Sleeping
File size: 6,904 Bytes
d477d5c d699829 d477d5c 5d4cc46 d477d5c 5d4cc46 d477d5c f685ddc d477d5c 5d4cc46 d477d5c f685ddc 24e48ac 5d4cc46 f685ddc d477d5c 5d4cc46 d477d5c 5d4cc46 d477d5c 5d4cc46 d477d5c 5d4cc46 d477d5c 5d4cc46 d477d5c 5d4cc46 d477d5c 5d4cc46 d477d5c 5d4cc46 d477d5c c56aab6 63e3e29 5d4cc46 c56aab6 5d4cc46 c56aab6 5d4cc46 c56aab6 5d4cc46 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
from __future__ import annotations
import ast
import csv
import json
from pathlib import Path
import random
from typing import TYPE_CHECKING
from uw_programmatic.base_machine import UWBaseMachine
if TYPE_CHECKING:
from griptape.tools import BaseTool
class UWMachine(UWBaseMachine):
"""State machine with GOAP"""
@property
def tools(self) -> dict[str, BaseTool]:
return {}
def start_machine(self) -> None:
"""Starts the machine."""
# Clear input history.
# Clear csv file
self.retrieve_vector_stores()
self.send("enter_first_state")
# The first state: Listens for Gradio and then gives us the parameters to search for.
# Reinitializes the Give Up counter.
def on_event_gather_parameters(self, event_: dict) -> None:
event_source = event_["type"]
event_value = event_["value"]
match event_source:
case "user_input":
parameters = event_value
self.page_range = parameters["page_range"]
self.question_number = parameters["question_number"]
self.taxonomy = parameters["taxonomy"]
self.current_question_count = 0
self.give_up_count = 0
self.send("next_state")
case _:
err_msg = f"Unexpected Transition Event ID: {event_value}."
raise ValueError(err_msg)
# Checks if there have not been any new questions generated 3 tries in a row
# If # of questions is the same as the # of questions required - sends to end.
def on_enter_evaluate_q_count(self) -> None:
if len(self.question_list) <= self.current_question_count:
self.give_up_count += 1
else:
self.current_question_count = len(self.question_list)
self.give_up_count = 0
if self.give_up_count >= 3:
self.send("finish_state") # go to output questions
return
if len(self.question_list) >= self.question_number:
self.send("finish_state") # go to output questions
else:
self.send("next_state") # go to need more questions
# Necessary for state machine to not throw errors
def on_event_evaluate_q_count(self, event_: dict) -> None:
pass
def on_enter_need_more_q(self) -> None:
# Create the entire workflow to create another question.
self.get_questions_workflow().run()
# Returns the output of the workflow - a ListArtifact of TextArtifacts of questions.
# Question, Answer, Wrong Answers, Taxonomy, Page Number
def on_event_need_more_q(self, event_: dict) -> None:
event_source = event_["type"]
event_value = event_["value"]
match event_source:
case "griptape_event":
event_type = event_value["type"]
match event_type:
case "FinishStructureRunEvent":
structure_id = event_value["structure_id"]
match structure_id:
case "create_question_workflow":
values = event_value["output_task_output"]["value"]
questions = [
ast.literal_eval(question["value"])
for question in values
]
self.most_recent_questions = (
questions # This is a ListArtifact
)
self.send("next_state")
case _:
print(f"Error:{event_} ")
case _:
print(f"Unexpected: {event_}")
# Merges the existing and new questions and sends to similarity auditor to get rid of similar questions.
def on_enter_assess_generated_q(self) -> None:
merged_list = [*self.question_list, *self.most_recent_questions]
prompt = f"{merged_list}"
self.get_structure("similarity_auditor").run(prompt)
# Sets the returned question list (with similar questions wiped) equal to self.question_list
def on_event_assess_generated_q(self, event_: dict) -> None:
event_source = event_["type"]
event_value = event_["value"]
match event_source:
case "griptape_event":
event_type = event_value["type"]
match event_type:
case "FinishStructureRunEvent":
structure_id = event_value["structure_id"]
match structure_id:
case "similarity_auditor":
new_question_list = event_value["output_task_output"][
"value"
]
try:
new_question_list = json.loads(
new_question_list
) # This must be in that JSON format
except: # If not in JSON decode format
new_question_list = self.question_list
self.question_list = new_question_list
self.send("next_state") # go to Evaluate Q Count
# Writes and saves a csv in the correct format to outputs/professor_guide.csv
def on_enter_output_q(self) -> None:
with Path(Path.cwd().joinpath("outputs/professor_guide.csv")).open(
"w", newline=""
) as file:
writer = csv.writer(file)
for question in range(len(self.question_list)):
new_row = ["MC", "", 1]
new_row.append(self.question_list[question]["Question"])
wrong_answers = list(self.question_list[question]["Wrong Answers"])
column = random.randint(1, len(wrong_answers) + 1)
new_row.append(column)
for i in range(1, len(wrong_answers) + 2):
if i == column:
new_row.append(self.question_list[question]["Answer"])
else:
new_row.append(wrong_answers.pop())
new_row.append(self.question_list[question]["Page"])
new_row.append(self.question_list[question]["Taxonomy"])
writer.writerow(new_row)
if self.give_up_count == 3:
writer.writerow(
[
"Failed to generate more questions.",
]
)
self.send("next_state") # back to gather_parameters
# Necessary to prevent errors being thrown from state machine
def on_event_output_q(self, event_: dict) -> None:
pass
|