Tai Truong
fix readme
d202ada
import ast
import contextlib
import inspect
import traceback
from itertools import starmap
from pathlib import Path
from typing import Any
from cachetools import TTLCache, keys
from fastapi import HTTPException
from loguru import logger
from langflow.custom.eval import eval_custom_component_code
from langflow.custom.schema import CallableCodeDetails, ClassCodeDetails, MissingDefault
class CodeSyntaxError(HTTPException):
    """HTTP error raised when the supplied source code cannot be parsed.

    Adds no behaviour of its own; it only gives syntax failures a distinct
    exception type on top of ``HTTPException``.
    """
def get_data_type():
    """Return the ``Data`` object exported by ``langflow.field_typing``.

    The import is performed inside the function, so it only runs when the
    function is called rather than when this module is loaded.
    """
    from langflow.field_typing import Data as data_type

    return data_type
def find_class_ast_node(class_obj):
    """Locate the AST definition of ``class_obj`` inside its own source file.

    Args:
        class_obj: The class whose defining file should be parsed.

    Returns:
        A ``(class_node, import_nodes)`` tuple. ``class_node`` is the
        ``ast.ClassDef`` whose name equals ``class_obj.__name__`` (the last one
        met while walking the tree, or ``None`` when there is none).
        ``import_nodes`` lists every ``ast.Import`` / ``ast.ImportFrom`` node
        found anywhere in that file. ``(None, [])`` is returned when no source
        file can be determined.
    """
    file_path = inspect.getsourcefile(class_obj)
    if not file_path:
        return None, []

    # Read the defining file and turn it into an AST in one go.
    module_tree = ast.parse(Path(file_path).read_text(encoding="utf-8"))

    wanted_name = class_obj.__name__
    matched_class = None
    collected_imports = []
    for candidate in ast.walk(module_tree):
        if isinstance(candidate, ast.ClassDef) and candidate.name == wanted_name:
            # Later matches overwrite earlier ones.
            matched_class = candidate
        elif isinstance(candidate, (ast.Import, ast.ImportFrom)):
            collected_imports.append(candidate)
    return matched_class, collected_imports
def imports_key(*args, **kwargs):
    """Build a cache key that also takes the ``imports`` keyword into account.

    The mandatory ``imports`` keyword argument is removed from ``kwargs``, the
    remaining arguments are handed to ``keys.methodkey``, and the import
    entries are appended to the resulting key as a tuple.

    Raises:
        KeyError: If no ``imports`` keyword argument was supplied.
    """
    import_entries = kwargs.pop("imports")
    base_key = keys.methodkey(*args, **kwargs)
    return base_key + tuple(import_entries)
class CodeParser:
"""A parser for Python source code, extracting code details."""
def __init__(self, code: str | type) -> None:
"""Initializes the parser with the provided code."""
self.cache: TTLCache = TTLCache(maxsize=1024, ttl=60)
if isinstance(code, type):
if not inspect.isclass(code):
msg = "The provided code must be a class."
raise ValueError(msg)
# If the code is a class, get its source code
code = inspect.getsource(code)
self.code = code
self.data: dict[str, Any] = {
"imports": [],
"functions": [],
"classes": [],
"global_vars": [],
}
self.handlers = {
ast.Import: self.parse_imports,
ast.ImportFrom: self.parse_imports,
ast.FunctionDef: self.parse_functions,
ast.ClassDef: self.parse_classes,
ast.Assign: self.parse_global_vars,
}
def get_tree(self):
"""Parses the provided code to validate its syntax.
It tries to parse the code into an abstract syntax tree (AST).
"""
try:
tree = ast.parse(self.code)
except SyntaxError as err:
raise CodeSyntaxError(
status_code=400,
detail={"error": err.msg, "traceback": traceback.format_exc()},
) from err
return tree
def parse_node(self, node: ast.stmt | ast.AST) -> None:
"""Parses an AST node and updates the data dictionary with the relevant information."""
if handler := self.handlers.get(type(node)):
handler(node) # type: ignore[operator]
def parse_imports(self, node: ast.Import | ast.ImportFrom) -> None:
"""Extracts "imports" from the code, including aliases."""
if isinstance(node, ast.Import):
for alias in node.names:
if alias.asname:
self.data["imports"].append(f"{alias.name} as {alias.asname}")
else:
self.data["imports"].append(alias.name)
elif isinstance(node, ast.ImportFrom):
for alias in node.names:
if alias.asname:
self.data["imports"].append((node.module, f"{alias.name} as {alias.asname}"))
else:
self.data["imports"].append((node.module, alias.name))
def parse_functions(self, node: ast.FunctionDef) -> None:
"""Extracts "functions" from the code."""
self.data["functions"].append(self.parse_callable_details(node))
def parse_arg(self, arg, default):
"""Parses an argument and its default value."""
arg_dict = {"name": arg.arg, "default": default}
if arg.annotation:
arg_dict["type"] = ast.unparse(arg.annotation)
return arg_dict
# @cachedmethod(operator.attrgetter("cache"))
def construct_eval_env(self, return_type_str: str, imports) -> dict:
"""Constructs an evaluation environment.
Constructs an evaluation environment with the necessary imports for the return type,
taking into account module aliases.
"""
eval_env: dict = {}
for import_entry in imports:
if isinstance(import_entry, tuple): # from module import name
module, name = import_entry
if name in return_type_str:
exec(f"import {module}", eval_env)
exec(f"from {module} import {name}", eval_env)
else: # import module
module = import_entry
alias = None
if " as " in module:
module, alias = module.split(" as ")
if module in return_type_str or (alias and alias in return_type_str):
exec(f"import {module} as {alias or module}", eval_env)
return eval_env
def parse_callable_details(self, node: ast.FunctionDef) -> dict[str, Any]:
"""Extracts details from a single function or method node."""
return_type = None
if node.returns:
return_type_str = ast.unparse(node.returns)
eval_env = self.construct_eval_env(return_type_str, tuple(self.data["imports"]))
# Handle cases where the type is not found in the constructed environment
with contextlib.suppress(NameError):
return_type = eval(return_type_str, eval_env) # noqa: S307
func = CallableCodeDetails(
name=node.name,
doc=ast.get_docstring(node),
args=self.parse_function_args(node),
body=self.parse_function_body(node),
return_type=return_type,
has_return=self.parse_return_statement(node),
)
return func.model_dump()
def parse_function_args(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
"""Parses the arguments of a function or method node."""
args = []
args += self.parse_positional_args(node)
args += self.parse_varargs(node)
args += self.parse_keyword_args(node)
# Commented out because we don't want kwargs
# showing up as fields in the frontend
args += self.parse_kwargs(node)
return args
def parse_positional_args(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
"""Parses the positional arguments of a function or method node."""
num_args = len(node.args.args)
num_defaults = len(node.args.defaults)
num_missing_defaults = num_args - num_defaults
missing_defaults = [MissingDefault()] * num_missing_defaults
default_values = [ast.unparse(default).strip("'") if default else None for default in node.args.defaults]
# Now check all default values to see if there
# are any "None" values in the middle
default_values = [None if value == "None" else value for value in default_values]
defaults = missing_defaults + default_values
return list(starmap(self.parse_arg, zip(node.args.args, defaults, strict=True)))
def parse_varargs(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
"""Parses the *args argument of a function or method node."""
args = []
if node.args.vararg:
args.append(self.parse_arg(node.args.vararg, None))
return args
def parse_keyword_args(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
"""Parses the keyword-only arguments of a function or method node."""
kw_defaults = [None] * (len(node.args.kwonlyargs) - len(node.args.kw_defaults)) + [
ast.unparse(default) if default else None for default in node.args.kw_defaults
]
return list(starmap(self.parse_arg, zip(node.args.kwonlyargs, kw_defaults, strict=True)))
def parse_kwargs(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
"""Parses the **kwargs argument of a function or method node."""
args = []
if node.args.kwarg:
args.append(self.parse_arg(node.args.kwarg, None))
return args
def parse_function_body(self, node: ast.FunctionDef) -> list[str]:
"""Parses the body of a function or method node."""
return [ast.unparse(line) for line in node.body]
def parse_return_statement(self, node: ast.FunctionDef) -> bool:
"""Parses the return statement of a function or method node, including nested returns."""
def has_return(node):
if isinstance(node, ast.Return):
return True
if isinstance(node, ast.If):
return any(has_return(child) for child in node.body) or any(has_return(child) for child in node.orelse)
if isinstance(node, ast.Try):
return (
any(has_return(child) for child in node.body)
or any(has_return(child) for child in node.handlers)
or any(has_return(child) for child in node.finalbody)
)
if isinstance(node, ast.For | ast.While):
return any(has_return(child) for child in node.body) or any(has_return(child) for child in node.orelse)
if isinstance(node, ast.With):
return any(has_return(child) for child in node.body)
return False
return any(has_return(child) for child in node.body)
def parse_assign(self, stmt):
"""Parses an Assign statement and returns a dictionary with the target's name and value."""
for target in stmt.targets:
if isinstance(target, ast.Name):
return {"name": target.id, "value": ast.unparse(stmt.value)}
return None
def parse_ann_assign(self, stmt):
"""Parses an AnnAssign statement and returns a dictionary with the target's name, value, and annotation."""
if isinstance(stmt.target, ast.Name):
return {
"name": stmt.target.id,
"value": ast.unparse(stmt.value) if stmt.value else None,
"annotation": ast.unparse(stmt.annotation),
}
return None
def parse_function_def(self, stmt):
"""Parse a FunctionDef statement.
Parse a FunctionDef statement and return the parsed method and a boolean indicating if it's an __init__ method.
"""
method = self.parse_callable_details(stmt)
return (method, True) if stmt.name == "__init__" else (method, False)
def get_base_classes(self):
"""Returns the base classes of the custom component class."""
try:
bases = self.execute_and_inspect_classes(self.code)
except Exception:
# If the code cannot be executed, return an empty list
bases = []
raise
return bases
def parse_classes(self, node: ast.ClassDef) -> None:
"""Extracts "classes" from the code, including inheritance and init methods."""
bases = self.get_base_classes()
nodes = []
for base in bases:
if base.__name__ == node.name or base.__name__ in {"CustomComponent", "Component", "BaseComponent"}:
continue
try:
class_node, import_nodes = find_class_ast_node(base)
if class_node is None:
continue
for import_node in import_nodes:
self.parse_imports(import_node)
nodes.append(class_node)
except Exception: # noqa: BLE001
logger.exception("Error finding base class node")
nodes.insert(0, node)
class_details = ClassCodeDetails(
name=node.name,
doc=ast.get_docstring(node),
bases=[b.__name__ for b in bases],
attributes=[],
methods=[],
init=None,
)
for _node in nodes:
self.process_class_node(_node, class_details)
self.data["classes"].append(class_details.model_dump())
def process_class_node(self, node, class_details) -> None:
for stmt in node.body:
if isinstance(stmt, ast.Assign):
if attr := self.parse_assign(stmt):
class_details.attributes.append(attr)
elif isinstance(stmt, ast.AnnAssign):
if attr := self.parse_ann_assign(stmt):
class_details.attributes.append(attr)
elif isinstance(stmt, ast.FunctionDef | ast.AsyncFunctionDef):
method, is_init = self.parse_function_def(stmt)
if is_init:
class_details.init = method
else:
class_details.methods.append(method)
def parse_global_vars(self, node: ast.Assign) -> None:
"""Extracts global variables from the code."""
global_var = {
"targets": [t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets],
"value": ast.unparse(node.value),
}
self.data["global_vars"].append(global_var)
def execute_and_inspect_classes(self, code: str):
custom_component_class = eval_custom_component_code(code)
custom_component = custom_component_class(_code=code)
dunder_class = custom_component.__class__
# Get the base classes at two levels of inheritance
bases = []
for base in dunder_class.__bases__:
bases.append(base)
bases.extend(base.__bases__)
return bases
def parse_code(self) -> dict[str, Any]:
"""Runs all parsing operations and returns the resulting data."""
tree = self.get_tree()
for node in ast.walk(tree):
self.parse_node(node)
return self.data