import ast
import contextlib
import inspect
import traceback
from itertools import starmap
from pathlib import Path
from typing import Any

from cachetools import TTLCache, keys
from fastapi import HTTPException
from loguru import logger

from langflow.custom.eval import eval_custom_component_code
from langflow.custom.schema import CallableCodeDetails, ClassCodeDetails, MissingDefault


class CodeSyntaxError(HTTPException):
    pass


def get_data_type():
    from langflow.field_typing import Data

    return Data


def find_class_ast_node(class_obj):
    """Finds the AST node corresponding to the given class object."""
    # Get the source file where the class is defined
    source_file = inspect.getsourcefile(class_obj)
    if not source_file:
        return None, []

    # Read the source code from the file
    source_code = Path(source_file).read_text(encoding="utf-8")

    # Parse the source code into an AST
    tree = ast.parse(source_code)

    # Search for the class definition node in the AST
    class_node = None
    import_nodes = []
    for node in ast.walk(tree):
        if isinstance(node, ast.ClassDef) and node.name == class_obj.__name__:
            class_node = node
        elif isinstance(node, ast.Import | ast.ImportFrom):
            import_nodes.append(node)

    return class_node, import_nodes


def imports_key(*args, **kwargs):
    """Builds a cache key from the method arguments plus the "imports" keyword argument."""
    imports = kwargs.pop("imports")
    key = keys.methodkey(*args, **kwargs)
    key += tuple(imports)
    return key


class CodeParser:
    """A parser for Python source code, extracting code details."""

    def __init__(self, code: str | type) -> None:
        """Initializes the parser with the provided code."""
        self.cache: TTLCache = TTLCache(maxsize=1024, ttl=60)
        if isinstance(code, type):
            if not inspect.isclass(code):
                msg = "The provided code must be a class."
                raise ValueError(msg)
            # If the code is a class, get its source code
            code = inspect.getsource(code)
        self.code = code
        self.data: dict[str, Any] = {
            "imports": [],
            "functions": [],
            "classes": [],
            "global_vars": [],
        }
        self.handlers = {
            ast.Import: self.parse_imports,
            ast.ImportFrom: self.parse_imports,
            ast.FunctionDef: self.parse_functions,
            ast.ClassDef: self.parse_classes,
            ast.Assign: self.parse_global_vars,
        }

    def get_tree(self):
        """Parses the provided code to validate its syntax.

        It tries to parse the code into an abstract syntax tree (AST).
        """
        try:
            tree = ast.parse(self.code)
        except SyntaxError as err:
            raise CodeSyntaxError(
                status_code=400,
                detail={"error": err.msg, "traceback": traceback.format_exc()},
            ) from err

        return tree

    def parse_node(self, node: ast.stmt | ast.AST) -> None:
        """Parses an AST node and updates the data dictionary with the relevant information."""
        if handler := self.handlers.get(type(node)):
            handler(node)  # type: ignore[operator]

    def parse_imports(self, node: ast.Import | ast.ImportFrom) -> None:
        """Extracts "imports" from the code, including aliases."""
        if isinstance(node, ast.Import):
            for alias in node.names:
                if alias.asname:
                    self.data["imports"].append(f"{alias.name} as {alias.asname}")
                else:
                    self.data["imports"].append(alias.name)
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                if alias.asname:
                    self.data["imports"].append((node.module, f"{alias.name} as {alias.asname}"))
                else:
                    self.data["imports"].append((node.module, alias.name))

    def parse_functions(self, node: ast.FunctionDef) -> None:
        """Extracts "functions" from the code."""
        self.data["functions"].append(self.parse_callable_details(node))

    def parse_arg(self, arg, default):
        """Parses an argument and its default value."""
        arg_dict = {"name": arg.arg, "default": default}
        if arg.annotation:
            arg_dict["type"] = ast.unparse(arg.annotation)
        return arg_dict

    # @cachedmethod(operator.attrgetter("cache"))
    def construct_eval_env(self, return_type_str: str, imports) -> dict:
        """Constructs an evaluation environment.

        Constructs an evaluation environment with the necessary imports for the return type,
        taking into account module aliases.
        """
        eval_env: dict = {}
        for import_entry in imports:
            if isinstance(import_entry, tuple):  # from module import name
                module, name = import_entry
                if name in return_type_str:
                    exec(f"import {module}", eval_env)
                    exec(f"from {module} import {name}", eval_env)
            else:  # import module
                module = import_entry
                alias = None
                if " as " in module:
                    module, alias = module.split(" as ")
                if module in return_type_str or (alias and alias in return_type_str):
                    # "import a.b as a.b" is a syntax error, so only emit "as" for a real alias
                    statement = f"import {module} as {alias}" if alias else f"import {module}"
                    exec(statement, eval_env)
        return eval_env

    def parse_callable_details(self, node: ast.FunctionDef) -> dict[str, Any]:
        """Extracts details from a single function or method node."""
        return_type = None
        if node.returns:
            return_type_str = ast.unparse(node.returns)
            eval_env = self.construct_eval_env(return_type_str, tuple(self.data["imports"]))

            # Handle cases where the type is not found in the constructed environment
            with contextlib.suppress(NameError):
                return_type = eval(return_type_str, eval_env)  # noqa: S307

        func = CallableCodeDetails(
            name=node.name,
            doc=ast.get_docstring(node),
            args=self.parse_function_args(node),
            body=self.parse_function_body(node),
            return_type=return_type,
            has_return=self.parse_return_statement(node),
        )

        return func.model_dump()

    def parse_function_args(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
        """Parses the arguments of a function or method node."""
        args = []

        args += self.parse_positional_args(node)
        args += self.parse_varargs(node)
        args += self.parse_keyword_args(node)
        # **kwargs is parsed as well, so it appears alongside the other arguments
        args += self.parse_kwargs(node)

        return args

    def parse_positional_args(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
        """Parses the positional arguments of a function or method node."""
        num_args = len(node.args.args)
        num_defaults = len(node.args.defaults)
        num_missing_defaults = num_args - num_defaults
        missing_defaults = [MissingDefault()] * num_missing_defaults
        default_values = [ast.unparse(default).strip("'") if default else None for default in node.args.defaults]
        # Now check all default values to see if there
        # are any "None" values in the middle
        default_values = [None if value == "None" else value for value in default_values]

        defaults = missing_defaults + default_values

        return list(starmap(self.parse_arg, zip(node.args.args, defaults, strict=True)))

    def parse_varargs(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
        """Parses the *args argument of a function or method node."""
        args = []

        if node.args.vararg:
            args.append(self.parse_arg(node.args.vararg, None))

        return args

    def parse_keyword_args(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
        """Parses the keyword-only arguments of a function or method node."""
        kw_defaults = [None] * (len(node.args.kwonlyargs) - len(node.args.kw_defaults)) + [
            ast.unparse(default) if default else None for default in node.args.kw_defaults
        ]

        return list(starmap(self.parse_arg, zip(node.args.kwonlyargs, kw_defaults, strict=True)))

    def parse_kwargs(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
        """Parses the **kwargs argument of a function or method node."""
        args = []

        if node.args.kwarg:
            args.append(self.parse_arg(node.args.kwarg, None))

        return args

    def parse_function_body(self, node: ast.FunctionDef) -> list[str]:
        """Parses the body of a function or method node."""
        return [ast.unparse(line) for line in node.body]

    def parse_return_statement(self, node: ast.FunctionDef) -> bool:
        """Parses the return statement of a function or method node, including nested returns."""

        def has_return(node):
            if isinstance(node, ast.Return):
                return True
            if isinstance(node, ast.If):
                return any(has_return(child) for child in node.body) or any(has_return(child) for child in node.orelse)
            if isinstance(node, ast.Try):
                return (
                    any(has_return(child) for child in node.body)
                    # Handlers are ast.ExceptHandler nodes, so look inside their bodies
                    or any(has_return(child) for handler in node.handlers for child in handler.body)
                    or any(has_return(child) for child in node.orelse)
                    or any(has_return(child) for child in node.finalbody)
                )
            if isinstance(node, ast.For | ast.While):
                return any(has_return(child) for child in node.body) or any(has_return(child) for child in node.orelse)
            if isinstance(node, ast.With):
                return any(has_return(child) for child in node.body)
            return False

        return any(has_return(child) for child in node.body)

    def parse_assign(self, stmt):
        """Parses an Assign statement and returns a dictionary with the target's name and value."""
        for target in stmt.targets:
            if isinstance(target, ast.Name):
                return {"name": target.id, "value": ast.unparse(stmt.value)}
        return None

    def parse_ann_assign(self, stmt):
        """Parses an AnnAssign statement and returns a dictionary with the target's name, value, and annotation."""
        if isinstance(stmt.target, ast.Name):
            return {
                "name": stmt.target.id,
                "value": ast.unparse(stmt.value) if stmt.value else None,
                "annotation": ast.unparse(stmt.annotation),
            }
        return None

    def parse_function_def(self, stmt):
        """Parse a FunctionDef statement.

        Parse a FunctionDef statement and return the parsed method and a boolean indicating if it's an __init__ method.
        """
        method = self.parse_callable_details(stmt)
        return method, stmt.name == "__init__"

    def get_base_classes(self):
        """Returns the base classes of the custom component class.

        Any exception raised while executing the code propagates to the caller.
        """
        return self.execute_and_inspect_classes(self.code)

    def parse_classes(self, node: ast.ClassDef) -> None:
        """Extracts "classes" from the code, including inheritance and init methods."""
        bases = self.get_base_classes()
        nodes = []
        for base in bases:
            if base.__name__ == node.name or base.__name__ in {"CustomComponent", "Component", "BaseComponent"}:
                continue
            try:
                class_node, import_nodes = find_class_ast_node(base)
                if class_node is None:
                    continue
                for import_node in import_nodes:
                    self.parse_imports(import_node)
                nodes.append(class_node)
            except Exception:  # noqa: BLE001
                logger.exception("Error finding base class node")
        nodes.insert(0, node)
        class_details = ClassCodeDetails(
            name=node.name,
            doc=ast.get_docstring(node),
            bases=[b.__name__ for b in bases],
            attributes=[],
            methods=[],
            init=None,
        )
        for _node in nodes:
            self.process_class_node(_node, class_details)
        self.data["classes"].append(class_details.model_dump())

    def process_class_node(self, node, class_details) -> None:
        """Collects the attributes, methods, and __init__ of a class node into class_details."""
        for stmt in node.body:
            if isinstance(stmt, ast.Assign):
                if attr := self.parse_assign(stmt):
                    class_details.attributes.append(attr)
            elif isinstance(stmt, ast.AnnAssign):
                if attr := self.parse_ann_assign(stmt):
                    class_details.attributes.append(attr)
            elif isinstance(stmt, ast.FunctionDef | ast.AsyncFunctionDef):
                method, is_init = self.parse_function_def(stmt)
                if is_init:
                    class_details.init = method
                else:
                    class_details.methods.append(method)

    def parse_global_vars(self, node: ast.Assign) -> None:
        """Extracts global variables from the code."""
        global_var = {
            "targets": [t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets],
            "value": ast.unparse(node.value),
        }
        self.data["global_vars"].append(global_var)

    def execute_and_inspect_classes(self, code: str):
        """Evaluates the component code, instantiates it, and returns its base classes."""
        custom_component_class = eval_custom_component_code(code)
        custom_component = custom_component_class(_code=code)
        dunder_class = custom_component.__class__
        # Get the base classes at two levels of inheritance
        bases = []
        for base in dunder_class.__bases__:
            bases.append(base)
            bases.extend(base.__bases__)
        return bases

    def parse_code(self) -> dict[str, Any]:
        """Runs all parsing operations and returns the resulting data."""
        tree = self.get_tree()

        for node in ast.walk(tree):
            self.parse_node(node)
        return self.data