import os import html from typing import Any, Union, cast, Optional from dataclasses import dataclass import networkx as nx from graphgen.utils import logger from .base_storage import BaseGraphStorage @dataclass class NetworkXStorage(BaseGraphStorage): @staticmethod def load_nx_graph(file_name) -> Optional[nx.Graph]: if os.path.exists(file_name): return nx.read_graphml(file_name) return None @staticmethod def write_nx_graph(graph: nx.Graph, file_name): logger.info("Writing graph with %d nodes, %d edges", graph.number_of_nodes(), graph.number_of_edges()) nx.write_graphml(graph, file_name) @staticmethod def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py Return the largest connected component of the graph, with nodes and edges sorted in a stable way. """ from graspologic.utils import largest_connected_component graph = graph.copy() graph = cast(nx.Graph, largest_connected_component(graph)) node_mapping = { node: html.unescape(node.upper().strip()) for node in graph.nodes() } # type: ignore graph = nx.relabel_nodes(graph, node_mapping) return NetworkXStorage._stabilize_graph(graph) @staticmethod def _stabilize_graph(graph: nx.Graph) -> nx.Graph: """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py Ensure an undirected graph with the same relationships will always be read the same way. 通过对节点和边进行排序来实现 """ fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph() sorted_nodes = graph.nodes(data=True) sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0]) fixed_graph.add_nodes_from(sorted_nodes) edges = list(graph.edges(data=True)) if not graph.is_directed(): def _sort_source_target(edge): source, target, edge_data = edge if source > target: source, target = target, source return source, target, edge_data edges = [_sort_source_target(edge) for edge in edges] def _get_edge_key(source: Any, target: Any) -> str: return f"{source} -> {target}" edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1])) fixed_graph.add_edges_from(edges) return fixed_graph def __post_init__(self): """ 如果图文件存在,则加载图文件,否则创建一个新图 """ self._graphml_xml_file = os.path.join( self.working_dir, f"{self.namespace}.graphml" ) preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) if preloaded_graph is not None: logger.info( "Loaded graph from %s with %d nodes, %d edges", self._graphml_xml_file, preloaded_graph.number_of_nodes(), preloaded_graph.number_of_edges() ) self._graph = preloaded_graph or nx.Graph() async def index_done_callback(self): NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) async def has_node(self, node_id: str) -> bool: return self._graph.has_node(node_id) async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: return self._graph.has_edge(source_node_id, target_node_id) async def get_node(self, node_id: str) -> Union[dict, None]: return self._graph.nodes.get(node_id) async def get_all_nodes(self) -> Union[list[dict], None]: return self._graph.nodes(data=True) async def node_degree(self, node_id: str) -> int: return self._graph.degree(node_id) async def edge_degree(self, src_id: str, tgt_id: str) -> int: return self._graph.degree(src_id) + self._graph.degree(tgt_id) async def get_edge( self, source_node_id: str, target_node_id: str ) -> Union[dict, None]: return self._graph.edges.get((source_node_id, target_node_id)) async def get_all_edges(self) -> Union[list[dict], None]: return self._graph.edges(data=True) async def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]: if self._graph.has_node(source_node_id): return list(self._graph.edges(source_node_id, data=True)) return None async def get_graph(self) -> nx.Graph: return self._graph async def upsert_node(self, node_id: str, node_data: dict[str, str]): self._graph.add_node(node_id, **node_data) async def update_node(self, node_id: str, node_data: dict[str, str]): if self._graph.has_node(node_id): self._graph.nodes[node_id].update(node_data) else: logger.warning("Node %s not found in the graph for update.", node_id) async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ): self._graph.add_edge(source_node_id, target_node_id, **edge_data) async def update_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]): if self._graph.has_edge(source_node_id, target_node_id): self._graph.edges[(source_node_id, target_node_id)].update(edge_data) else: logger.warning("Edge %s -> %s not found in the graph for update.", source_node_id, target_node_id) async def delete_node(self, node_id: str): """ Delete a node from the graph based on the specified node_id. :param node_id: The node_id to delete """ if self._graph.has_node(node_id): self._graph.remove_node(node_id) logger.info("Node %s deleted from the graph.", node_id) else: logger.warning("Node %s not found in the graph for deletion.", node_id) async def clear(self): """ Clear the graph by removing all nodes and edges. """ self._graph.clear() logger.info("Graph %s cleared.", self.namespace)