import contextlib
import re

# Matches the shortest "{...}" span; group 1 is the text between the braces.
_VARIABLE_PATTERN = re.compile(r"\{(.*?)\}")


def _find_template_variables(template):
    """Return the brace-delimited names found in a node template's text.

    Args:
        template: Mapping with a ``"_type"`` entry. For ``"prompt"`` the text
            is read from ``template["template"]["value"]``; for ``"few_shot"``
            it is ``prefix`` value concatenated with ``suffix`` value.

    Returns:
        The captured names in order of appearance (duplicates are kept), or
        an empty list for any other ``_type``.
    """
    template_type = template["_type"]
    if template_type == "prompt":
        text = template["template"]["value"]
    elif template_type == "few_shot":
        text = template["prefix"]["value"] + template["suffix"]["value"]
    else:
        return []
    return _VARIABLE_PATTERN.findall(text)


def extract_input_variables(nodes):
    """Fill each node's ``input_variables`` field from its template text.

    For every node whose ``node["data"]["node"]["template"]`` contains an
    ``"input_variables"`` entry, the ``"value"`` of that entry is overwritten
    with the ``{name}`` placeholders found in the template text (an empty
    list when the template ``_type`` is neither ``"prompt"`` nor
    ``"few_shot"``).

    Args:
        nodes: Iterable of nested node mappings. Mutated in place.

    Returns:
        The same ``nodes`` object that was passed in.
    """
    for node in nodes:
        # Deliberately best effort: a node with missing keys or unexpected
        # value types is left exactly as it was instead of aborting the
        # whole pass over ``nodes``.
        with contextlib.suppress(Exception):
            template = node["data"]["node"]["template"]
            if "input_variables" not in template:
                continue
            template["input_variables"]["value"] = _find_template_variables(
                template
            )
    return nodes


def get_root_vertex(graph):
    """Return the first vertex that is not the source of any edge.

    Args:
        graph: Object exposing ``edges`` (items with a ``source_id``) and
            ``vertices`` (items with an ``id``).

    Returns:
        The first vertex in ``graph.vertices`` whose ``id`` never appears as
        an edge ``source_id``, or ``None`` when every vertex is a source (or
        there are no vertices).
    """
    source_ids = {edge.source_id for edge in graph.edges}
    # A lone vertex in an edgeless graph is returned without reading its id.
    if not source_ids and len(graph.vertices) == 1:
        return graph.vertices[0]
    return next(
        (vertex for vertex in graph.vertices if vertex.id not in source_ids),
        None,
    )


def _resolve_field(field, local_nodes, graph):
    """Compute the final value for a single template field.

    Args:
        field: Field spec mapping. ``"type"`` is always read; ``"required"``
            and ``"list"`` are read only when children have to be looked up.
        local_nodes: Nodes whose children are searched for a matching type.
        graph: Object exposing ``get_children_by_node_type(node, node_type)``.

    Returns:
        The explicit ``"value"`` when it is present and not ``None``; an
        empty dict when the type string contains ``"dict"``; otherwise the
        built children (all of them as a list when ``field["list"]`` is
        truthy, else the first one or ``None``).

    Raises:
        ValueError: If the field is required and no child matches its type.
    """
    node_type = field["type"]
    if field.get("value") is not None:
        return field["value"]
    if "dict" in node_type:
        return {}

    children = [
        child
        for local_node in local_nodes
        for child in graph.get_children_by_node_type(local_node, node_type)
    ]
    if field["required"] and not children:
        msg = f"No child with type {node_type} found"
        raise ValueError(msg)

    # Every child is built even when only the first one is kept, so an
    # invalid sibling still surfaces as an error.
    values = [build_json(child, graph) for child in children]
    if field["list"]:
        return values
    return values[0] if values else None


def build_json(root, graph) -> dict:
    """Recursively turn the template of ``root`` into a plain dictionary.

    Args:
        root: Vertex with a ``data`` mapping. When ``data`` has no ``"node"``
            key, ``root.edges[0].target`` is used as its only related node.
        graph: Object exposing ``get_nodes_with_target(node)`` and
            ``get_children_by_node_type(node, node_type)``.

    Returns:
        A shallow copy of ``root.data["node"]["template"]`` in which every
        key except ``"_type"`` is replaced by its resolved value. When
        exactly one related node is found, the result of building that node
        is returned instead.

    Raises:
        ValueError: If a required field has no child of the requested type.
    """
    if "node" not in root.data:
        # No template on this vertex: follow its first edge. This always
        # yields exactly one node, so the call delegates just below.
        local_nodes = [root.edges[0].target]
    else:
        local_nodes = graph.get_nodes_with_target(root)

    if len(local_nodes) == 1:
        return build_json(local_nodes[0], graph)

    template = root.data["node"]["template"]
    final_dict = template.copy()
    # Iterate the untouched template and write into the copy, so the
    # mapping being read is never the one being modified.
    for key, field in template.items():
        if key == "_type":
            continue
        final_dict[key] = _resolve_field(field, local_nodes, graph)
    return final_dict