Tai Truong
fix readme
d202ada
import contextlib
import re
def extract_input_variables(nodes):
    """Populate each prompt node's ``input_variables`` from its template text.

    For every node whose ``node["data"]["node"]["template"]`` dict contains an
    ``"input_variables"`` field, the ``{name}`` placeholders are collected
    (non-greedy ``\\{(.*?)\\}`` regex, in order of appearance, duplicates
    included) and written to ``template["input_variables"]["value"]``:

    * ``_type == "prompt"``: placeholders in ``template["template"]["value"]``.
    * ``_type == "few_shot"``: placeholders in the prefix value followed by the
      suffix value.
    * any other ``_type``: an empty list.

    Nodes whose structure does not match (missing keys, non-dict or non-string
    values) are skipped and left unchanged: the write is the last step, so a
    failure before it leaves the node untouched.

    NOTE(review): ``{{``-escaped braces are not special-cased by the regex.

    Args:
        nodes: Iterable of node dicts; matching nodes are mutated in place.

    Returns:
        The same ``nodes`` object that was passed in.
    """
    for node in nodes:
        # Best-effort by design: only suppress the errors a malformed node can
        # produce during these lookups, so genuine bugs still surface.
        with contextlib.suppress(KeyError, TypeError):
            template = node["data"]["node"]["template"]
            if "input_variables" not in template:
                continue

            template_type = template["_type"]
            if template_type == "prompt":
                text = template["template"]["value"]
            elif template_type == "few_shot":
                text = template["prefix"]["value"] + template["suffix"]["value"]
            else:
                # Unknown template kinds get no variables (findall("") == []).
                text = ""

            template["input_variables"]["value"] = re.findall(r"\{(.*?)\}", text)
    return nodes
def get_root_vertex(graph):
    """Return the template's root vertex.

    The root is the first vertex, in ``graph.vertices`` order, whose ``id``
    does not appear as the ``source_id`` of any edge. If the graph has no
    edges and exactly one vertex, that vertex is returned directly.

    Returns:
        The root vertex, or ``None`` when every vertex is an edge source or
        the graph has no vertices.
    """
    # Ids of every vertex that has at least one outgoing edge.
    sources = set()
    for edge in graph.edges:
        sources.add(edge.source_id)

    if not sources and len(graph.vertices) == 1:
        return graph.vertices[0]

    for vertex in graph.vertices:
        if vertex.id not in sources:
            return vertex
    return None
def build_json(root, graph) -> dict:
    """Recursively build a nested dict from the template of ``root``.

    Resolution rules:

    * If ``root.data`` has no ``"node"`` key, ``root`` is a pass-through and
      the target of its first edge is built instead.
    * Otherwise the vertices returned by ``graph.get_nodes_with_target(root)``
      are collected; if there is exactly one, that vertex is built instead.
    * Otherwise every field of ``root.data["node"]["template"]`` except
      ``"_type"`` (copied as-is) is resolved:

      - an explicit non-``None`` ``"value"`` is used directly;
      - a field whose ``"type"`` contains the substring ``"dict"`` becomes
        ``{}``;
      - otherwise the children of the collected vertices matching the
        field's ``"type"`` are built recursively and stored as a list when
        the field's ``"list"`` flag is truthy, else as the first one (or
        ``None`` when there are none).

    The template itself is not mutated; a shallow copy is returned.

    Args:
        root: Vertex to build from.
        graph: Graph exposing ``get_nodes_with_target`` and
            ``get_children_by_node_type``.

    Returns:
        The resolved template dict.

    Raises:
        ValueError: If a field is marked ``"required"`` but no matching child
            vertex exists.
        IndexError: If ``root`` has no ``"node"`` data and no edges.
        KeyError: If a field lacks ``"type"``, or lacks ``"required"`` /
            ``"list"`` when it has to be built from children.
    """
    if "node" not in root.data:
        # A vertex without node data delegates to the target of its single
        # outgoing edge.
        local_nodes = [root.edges[0].target]
    else:
        local_nodes = graph.get_nodes_with_target(root)

    if len(local_nodes) == 1:
        return build_json(local_nodes[0], graph)

    template = root.data["node"]["template"]
    # Shallow copy: entries are replaced below, never mutated in place, so
    # the caller's template is left untouched.
    final_dict = template.copy()

    for key, field in template.items():
        if key == "_type":
            continue

        field_type = field["type"]
        if "value" in field and field["value"] is not None:
            # An explicitly specified value wins.
            resolved = field["value"]
        elif "dict" in field_type:
            # Dict-typed fields without a value default to an empty dict.
            resolved = {}
        else:
            # Gather matching children across all collected vertices, in order.
            children = [
                child
                for local_node in local_nodes
                for child in graph.get_children_by_node_type(local_node, field_type)
            ]
            if field["required"] and not children:
                msg = f"No child with type {field_type} found"
                raise ValueError(msg)
            built = [build_json(child, graph) for child in children]
            resolved = built if field["list"] else next(iter(built), None)
        final_dict[key] = resolved

    return final_dict