Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import ast | |
import csv | |
import json | |
from pathlib import Path | |
import random | |
from typing import TYPE_CHECKING | |
from uw_programmatic.base_machine import UWBaseMachine | |
if TYPE_CHECKING: | |
from griptape.tools import BaseTool | |
class UWMachine(UWBaseMachine): | |
"""State machine with GOAP""" | |
def tools(self) -> dict[str, BaseTool]: | |
return {} | |
def start_machine(self) -> None: | |
"""Starts the machine.""" | |
# Clear input history. | |
# Clear csv file | |
self.send("enter_first_state") | |
def on_event_gather_parameters(self, event_: dict) -> None: | |
event_source = event_["type"] | |
event_value = event_["value"] | |
match event_source: | |
case "user_input": | |
parameters = event_value | |
self.page_range = parameters["page_range"] | |
self.question_number = parameters["question_number"] | |
self.taxonomy = parameters["taxonomy"] | |
self.current_question_count = 0 | |
self.give_up_count = 0 | |
self.send("next_state") | |
case _: | |
err_msg = f"Unexpected Transition Event ID: {event_value}." | |
raise ValueError(err_msg) | |
def on_enter_evaluate_q_count(self) -> None: | |
# Check if the number of questions has incremented | |
if len(self.question_list) <= self.current_question_count: | |
self.give_up_count += 1 | |
else: | |
self.current_question_count = len(self.question_list) | |
self.give_up_count = 0 | |
if self.give_up_count == 3: | |
self.send("finish_state") | |
return | |
if len(self.question_list) >= self.question_number: | |
self.send("finish_state") # go to output questions | |
else: | |
self.send("next_state") # go to need more questions | |
def on_event_evaluate_q_count(self, event_: dict) -> None: | |
pass | |
def on_enter_need_more_q(self) -> None: | |
# Create the entire workflow to create another question. | |
self.get_questions_workflow().run() | |
def on_event_need_more_q(self, event_: dict) -> None: | |
event_source = event_["type"] | |
event_value = event_["value"] | |
match event_source: | |
case "griptape_event": | |
event_type = event_value["type"] | |
match event_type: | |
case "FinishStructureRunEvent": | |
structure_id = event_value["structure_id"] | |
match structure_id: | |
case "create_question_workflow": | |
values = event_value["output_task_output"]["value"] | |
questions = [ | |
ast.literal_eval(question["value"]) | |
for question in values | |
] | |
self.most_recent_questions = ( | |
questions # This is a ListArtifact I'm pretty sure | |
) | |
self.send("next_state") | |
case _: | |
print(f"Error:{event_} ") | |
case _: | |
print(f"Unexpected: {event_}") | |
def on_enter_assess_generated_q(self) -> None: | |
# TODO: Should it append it to the list already and remove duplicates? or not? | |
# TODO: Merge incoming lists | |
merged_list = [*self.question_list, *self.most_recent_questions] | |
prompt = f"{merged_list}" | |
self.get_structure("similarity_auditor").run(prompt) | |
def on_event_assess_generated_q(self, event_: dict) -> None: | |
event_source = event_["type"] | |
event_value = event_["value"] | |
match event_source: | |
case "griptape_event": | |
event_type = event_value["type"] | |
match event_type: | |
case "FinishStructureRunEvent": | |
structure_id = event_value["structure_id"] | |
match structure_id: | |
case "similarity_auditor": | |
new_question_list = event_value["output_task_output"][ | |
"value" | |
] | |
try: | |
new_question_list = json.loads( | |
new_question_list | |
) # This must be in that JSON format | |
except: | |
new_question_list = self.question_list | |
self.question_list = new_question_list | |
self.send("next_state") # move on | |
def on_enter_output_q(self) -> None: | |
with Path(Path.cwd().joinpath("outputs/professor_guide.csv")).open( | |
"w", newline="" | |
) as file: | |
writer = csv.writer(file) | |
for question in range(len(self.question_list)): | |
new_row = ["MC", "", 1] | |
try: | |
new_row.append(self.question_list[question]["Question"]) | |
wrong_answers = list(self.question_list[question]["Wrong Answers"]) | |
column = random.randint(1, len(wrong_answers) + 1) | |
new_row.append(column) | |
for i in range(1, len(wrong_answers) + 2): | |
if i == column: | |
new_row.append(self.question_list[question]["Answer"]) | |
else: | |
wrong_answer = wrong_answers.pop() | |
if not wrong_answer: | |
wrong_answer = "" | |
new_row.append(wrong_answer) | |
new_row.append(self.question_list[question]["Page"]) | |
new_row.append(self.question_list[question]["Taxonomy"]) | |
writer.writerow(new_row) | |
except KeyError: | |
new_row.append(self.question_list["Question"]) | |
wrong_answers = list(self.question_list["Wrong Answers"]) | |
column = random.randint(1, len(wrong_answers) + 1) | |
new_row.append(column) | |
for i in range(1, len(wrong_answers) + 2): | |
if i == column: | |
new_row.append(self.question_list["Answer"]) | |
else: | |
new_row.append(wrong_answers.pop()) | |
new_row.append(self.question_list["Page"]) | |
new_row.append(self.question_list["Taxonomy"]) | |
writer.writerow(new_row) | |
if self.give_up_count == 3: | |
writer.writerow("Failed to generate more questions.") | |
self.send("next_state") | |
def on_event_output_q(self, event_: dict) -> None: | |
pass | |
def on_exit_output_q(self) -> None: | |
# Reset the state machine values | |
self.question_list = [] | |
self.most_recent_questions = [] | |
if __name__ == "__main__":
    # Stand-alone demo: write two sample questions to
    # outputs/professor_guide.csv. NOTE: these rows omit the "MC", "", 1
    # leading cells that UWMachine.on_enter_output_q writes.
    question_list = [
        {
            "Page": "1-2",
            "Taxonomy": "Knowledge",
            "Question": "What is Python?",
            "Answer": "A programming language",
            "Wrong Answers": ["A snake", "A car brand", "A fruit"],
        },
        {
            "Page": "3-4",
            "Taxonomy": "Comprehension",
            "Question": "What does HTML stand for?",
            "Answer": "HyperText Markup Language",
            "Wrong Answers": [
                "High Text Machine Language",
                "Hyperlink Text Mode Language",
                "None of the above",
            ],
        },
    ]
    output_path = Path.cwd() / "outputs" / "professor_guide.csv"
    # Create the output directory so the demo works from a fresh checkout.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        for question in question_list:
            # The correct answer goes in a random 1-based answer column, and
            # that column number is written just before the answers.
            wrong_answers = list(question["Wrong Answers"])
            column = random.randint(1, len(wrong_answers) + 1)
            new_row = [question["Question"], column]
            for i in range(1, len(wrong_answers) + 2):
                if i == column:
                    new_row.append(question["Answer"])
                else:
                    new_row.append(wrong_answers.pop())
            new_row.append(question["Page"])
            new_row.append(question["Taxonomy"])
            writer.writerow(new_row)