# File-view header captured with this source (not part of the module):
#   author: kateforsberg
#   commit: d477d5c ("first commit")
#   size:   3.38 kB
from __future__ import annotations
import schema
import yaml
from attrs import define
from yaml.resolver import Resolver
from griptape_statemachine.parsers.base_parser import BaseParser
# Schema for a single structure definition. Every field is optional:
#   model          -> str
#   ruleset_ids    -> list of str (presumably ids from the top-level "rulesets" map; verify)
#   vector_stores  -> list of str (presumably ids from the top-level "vector_stores" map; verify)
#   prompt_id      -> str (presumably an id from the top-level "prompts" map; verify)
# Used both for the top-level "structures" map and for each state's "structures" map.
_STRUCTURE_FIELDS = {
    schema.Optional("model"): str,
    schema.Optional("ruleset_ids"): [str],
    schema.Optional("vector_stores"): [str],
    schema.Optional("prompt_id"): str,
}
STRUCTURE_SCHEMA = schema.Schema(_STRUCTURE_FIELDS)
# ---------------------------------------------------------------------------
# Sub-schemas composed into CONFIG_SCHEMA at the bottom of this section.
# ---------------------------------------------------------------------------

# rulesets.<id>: a "name" string plus a list of rule strings.
_RULESET_SCHEMA = schema.Schema({"name": str, "rules": [str]})

# vector_stores.<id>: "file_path" and "file_type" strings, plus an optional
# integer "max_tokens".
_VECTOR_STORE_SCHEMA = schema.Schema(
    {
        "file_path": str,
        "file_type": str,
        schema.Optional("max_tokens"): int,
    }
)

# events.<name>.transitions[]: one transition between two named states, with
# optional "internal" flag and optional "on"/"relevance" strings.
_TRANSITION_SCHEMA = schema.Schema(
    {
        "from": str,
        "to": str,
        schema.Optional("internal"): bool,
        schema.Optional("on"): str,
        schema.Optional("relevance"): str,
    }
)

# events.<name>: the required list of transitions for that event.
_EVENT_SCHEMA = schema.Schema({"transitions": [_TRANSITION_SCHEMA]})

# states.<name>: optional boolean "initial"/"final" keys and an optional map of
# structure definitions.
_STATE_SCHEMA = schema.Schema(
    {
        schema.Optional(schema.Or("initial", "final")): bool,  # pyright: ignore[reportArgumentType]
        schema.Optional("structures"): schema.Schema({str: STRUCTURE_SCHEMA}),
    }
)

# prompts.<id>: required "prompt" text and optional "author_intent".
_PROMPT_FIELDS = {schema.Optional("author_intent"): str, "prompt": str}

# Top-level config document. "rulesets", "structures", "events" and "states"
# are required; "vector_stores" and "prompts" are optional.
CONFIG_SCHEMA = schema.Schema(
    {
        "rulesets": schema.Schema({str: _RULESET_SCHEMA}),
        schema.Optional("vector_stores"): schema.Schema({str: _VECTOR_STORE_SCHEMA}),
        "structures": schema.Schema({str: STRUCTURE_SCHEMA}),
        "events": schema.Schema({str: _EVENT_SCHEMA}),
        "states": schema.Schema({str: _STATE_SCHEMA}),
        schema.Optional("prompts"): {str: _PROMPT_FIELDS},
    }
)
@define()
class UWConfigParser(BaseParser):
    """Parser for the state-machine YAML config file.

    Reads ``self.file_path`` (presumably a ``pathlib.Path`` attribute declared
    on ``BaseParser`` -- confirm), validates the document against
    ``CONFIG_SCHEMA``, and can write a config mapping back to the same file.
    """

    def __attrs_post_init__(self) -> None:
        """Stop PyYAML from resolving On/Off/Yes/No (any case) to booleans.

        The bool implicit resolver is removed from the first-character buckets
        ``O o Y y N n``. Other first characters are left untouched.

        NOTE: ``Resolver.yaml_implicit_resolvers`` is class-level state, so this
        change is process-wide rather than per instance. It runs once per parser
        instance, so it must be idempotent.
        """
        bool_tag = "tag:yaml.org,2002:bool"
        resolvers = Resolver.yaml_implicit_resolvers
        for ch in "OoYyNn":
            if ch not in resolvers:
                continue
            # Always filter by tag; never delete a bucket just because it has a
            # single entry. Other resolvers can share a first character with the
            # bool one (in PyYAML, "n"/"N" also hold the null resolver). After a
            # first run those buckets contain only the null resolver, and
            # deleting them would make "null" load as a string.
            kept = [entry for entry in resolvers[ch] if entry[0] != bool_tag]
            if kept:
                resolvers[ch] = kept
            else:
                del resolvers[ch]

    def parse(self) -> dict:
        """Load the YAML config file and validate it against ``CONFIG_SCHEMA``.

        Returns:
            The document exactly as produced by ``yaml.safe_load``.

        Raises:
            schema.SchemaError: if the document does not match ``CONFIG_SCHEMA``.
        """
        # Explicit UTF-8 instead of the platform locale encoding.
        data = yaml.safe_load(self.file_path.read_text(encoding="utf-8"))
        CONFIG_SCHEMA.validate(data)
        return data

    def update_and_save(self, config: dict) -> None:
        """Overwrite ``self.file_path`` with ``config`` as block-style YAML.

        ``config`` is not validated against ``CONFIG_SCHEMA`` here.

        NOTE(review): ``yaml.dump`` (not ``safe_dump``) can emit ``!!python/*``
        tags for non-plain objects, which ``safe_load`` in ``parse`` would
        reject -- presumably ``config`` only holds plain dicts/lists/scalars.

        Args:
            config: the configuration mapping to write.
        """
        # newline="\n" disables text-mode newline translation, so the
        # line_break="\n" requested from PyYAML actually reaches the file
        # (otherwise it becomes "\r\n" on Windows).
        with self.file_path.open("w", encoding="utf-8", newline="\n") as file:
            yaml.dump(config, file, default_flow_style=False, line_break="\n")