import random
from collections import defaultdict

from tqdm.asyncio import tqdm as tqdm_async

from graphgen.models import NetworkXStorage, TraverseStrategy
from graphgen.utils import logger


async def _get_node_info(
    node_id: str,
    graph_storage: NetworkXStorage,
) -> dict:
    """
    Get node info.

    :param node_id: node id
    :param graph_storage: graph storage instance
    :return: node info
    """
    node_data = await graph_storage.get_node(node_id)
    return {
        "node_id": node_id,
        **node_data,
    }


def _get_level_n_edges_by_max_width(
    edge_adj_list: dict,
    node_dict: dict,
    edges: list,
    nodes: list,
    src_edge: tuple,
    max_depth: int,
    bidirectional: bool,
    max_extra_edges: int,
    edge_sampling: str,
    loss_strategy: str = "only_edge",
) -> list:
    """
    Get level n edges for an edge, where n is bounded by max_depth in
    traverse_strategy.

    :param edge_adj_list: node name -> indices of incident edges
    :param node_dict: node name -> index into nodes
    :param edges: all edges
    :param nodes: all nodes
    :param src_edge: the edge to expand from
    :param max_depth: maximum expansion depth
    :param bidirectional: expand from both endpoints of src_edge
    :param max_extra_edges: maximum number of extra edges to collect
    :param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
    :param loss_strategy: loss strategy (only_edge, both)
    :return: level n edges
    """
    src_id, tgt_id, _ = src_edge

    level_n_edges = []
    start_nodes = {src_id, tgt_id} if bidirectional else {tgt_id}

    while max_depth > 0 and max_extra_edges > 0:
        max_depth -= 1
        # Frontier expansion: all unvisited edges incident to the current start nodes.
        candidate_edges = [
            edges[edge_id]
            for node in start_nodes
            for edge_id in edge_adj_list[node]
            if not edges[edge_id][2].get("visited", False)
        ]

        if not candidate_edges:
            break

        if len(candidate_edges) >= max_extra_edges:
            # Budget exhausted at this level: sample, take what fits, and stop.
            if loss_strategy == "both":
                er_tuples = [
                    ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
                    for edge in candidate_edges
                ]
                candidate_edges = _sort_tuples(er_tuples, edge_sampling)[:max_extra_edges]
            elif loss_strategy == "only_edge":
                candidate_edges = _sort_edges(candidate_edges, edge_sampling)[:max_extra_edges]
            else:
                raise ValueError(f"Invalid loss strategy: {loss_strategy}")

            for edge in candidate_edges:
                level_n_edges.append(edge)
                edge[2]["visited"] = True
            break

        max_extra_edges -= len(candidate_edges)
        new_start_nodes = set()
        for edge in candidate_edges:
            level_n_edges.append(edge)
            edge[2]["visited"] = True
            if edge[0] not in start_nodes:
                new_start_nodes.add(edge[0])
            if edge[1] not in start_nodes:
                new_start_nodes.add(edge[1])

        start_nodes = new_start_nodes

    return level_n_edges
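# Data shapes assumed throughout this module (inferred from usage above; not an
# authoritative schema):
#   nodes: [(node_name, {"length": <int>, "loss": <float>, ...}), ...]
#   edges: [(src_name, tgt_name, {"length": <int>, "loss": <float>, ...}), ...]
# The "visited" flag is written in place on edge attribute dicts, so each edge
# is assigned to at most one batch across the whole traversal.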
def _get_level_n_edges_by_max_tokens(
    edge_adj_list: dict,
    node_dict: dict,
    edges: list,
    nodes: list,
    src_edge: tuple,
    max_depth: int,
    bidirectional: bool,
    max_tokens: int,
    edge_sampling: str,
    loss_strategy: str = "only_edge",
) -> list:
    """
    Get level n edges for an edge, where n is bounded by max_depth in
    traverse_strategy.

    :param edge_adj_list: node name -> indices of incident edges
    :param node_dict: node name -> index into nodes
    :param edges: all edges
    :param nodes: all nodes
    :param src_edge: the edge to expand from
    :param max_depth: maximum expansion depth
    :param bidirectional: expand from both endpoints of src_edge
    :param max_tokens: token budget for the batch
    :param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
    :param loss_strategy: loss strategy (only_edge, both)
    :return: level n edges
    """
    src_id, tgt_id, src_edge_data = src_edge
    # Charge the source edge and both of its endpoints against the token budget.
    max_tokens -= (
        src_edge_data["length"]
        + nodes[node_dict[src_id]][1]["length"]
        + nodes[node_dict[tgt_id]][1]["length"]
    )

    level_n_edges = []
    start_nodes = {src_id, tgt_id} if bidirectional else {tgt_id}
    temp_nodes = {src_id, tgt_id}

    while max_depth > 0 and max_tokens > 0:
        max_depth -= 1
        candidate_edges = [
            edges[edge_id]
            for node in start_nodes
            for edge_id in edge_adj_list[node]
            if not edges[edge_id][2].get("visited", False)
        ]

        if not candidate_edges:
            break

        if loss_strategy == "both":
            er_tuples = [
                ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
                for edge in candidate_edges
            ]
            candidate_edges = _sort_tuples(er_tuples, edge_sampling)
        elif loss_strategy == "only_edge":
            candidate_edges = _sort_edges(candidate_edges, edge_sampling)
        else:
            raise ValueError(f"Invalid loss strategy: {loss_strategy}")

        for edge in candidate_edges:
            # Nodes already counted in this batch are not charged again.
            max_tokens -= edge[2]["length"]
            if edge[0] not in temp_nodes:
                max_tokens -= nodes[node_dict[edge[0]]][1]["length"]
            if edge[1] not in temp_nodes:
                max_tokens -= nodes[node_dict[edge[1]]][1]["length"]

            if max_tokens < 0:
                return level_n_edges

            level_n_edges.append(edge)
            edge[2]["visited"] = True
            temp_nodes.add(edge[0])
            temp_nodes.add(edge[1])

        new_start_nodes = set()
        for edge in candidate_edges:
            if edge[0] not in start_nodes:
                new_start_nodes.add(edge[0])
            if edge[1] not in start_nodes:
                new_start_nodes.add(edge[1])

        start_nodes = new_start_nodes

    return level_n_edges


def _sort_tuples(er_tuples: list, edge_sampling: str) -> list:
    """
    Sort edges with edge sampling strategy.

    :param er_tuples: [(nodes: list, edge: tuple)]
    :param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
    :return: sorted edges
    """
    if edge_sampling == "random":
        er_tuples = random.sample(er_tuples, len(er_tuples))
    elif edge_sampling == "min_loss":
        # Rank by combined loss of both endpoint nodes plus the edge itself.
        er_tuples = sorted(
            er_tuples,
            key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
        )
    elif edge_sampling == "max_loss":
        er_tuples = sorted(
            er_tuples,
            key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
            reverse=True,
        )
    else:
        raise ValueError(f"Invalid edge sampling: {edge_sampling}")
    edges = [edge for _, edge in er_tuples]
    return edges


def _sort_edges(edges: list, edge_sampling: str) -> list:
    """
    Sort edges with edge sampling strategy.

    :param edges: total edges
    :param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
    :return: sorted edges
    """
    if edge_sampling == "random":
        random.shuffle(edges)
    elif edge_sampling == "min_loss":
        edges = sorted(edges, key=lambda x: x[2]["loss"])
    elif edge_sampling == "max_loss":
        edges = sorted(edges, key=lambda x: x[2]["loss"], reverse=True)
    else:
        raise ValueError(f"Invalid edge sampling: {edge_sampling}")
    return edges
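# Worked example (hypothetical values): with edge_sampling="min_loss",
#   _sort_edges([("a", "b", {"loss": 0.9}), ("b", "c", {"loss": 0.1})], "min_loss")
# yields the ("b", "c", ...) edge first, so low-loss edges are expanded before
# high-loss ones; "max_loss" reverses that order, and "random" shuffles in place.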
async def get_batches_with_strategy(  # pylint: disable=too-many-branches
    nodes: list,
    edges: list,
    graph_storage: NetworkXStorage,
    traverse_strategy: TraverseStrategy,
):
    expand_method = traverse_strategy.expand_method
    if expand_method == "max_width":
        logger.info("Using max width strategy")
    elif expand_method == "max_tokens":
        logger.info("Using max tokens strategy")
    else:
        raise ValueError(f"Invalid expand method: {expand_method}")

    max_depth = traverse_strategy.max_depth
    edge_sampling = traverse_strategy.edge_sampling

    # Build an adjacency list mapping each node name to the indices of its incident edges.
    edge_adj_list = defaultdict(list)
    node_dict = {}
    processing_batches = []

    node_cache = {}

    async def get_cached_node_info(node_id: str) -> dict:
        if node_id not in node_cache:
            node_cache[node_id] = await _get_node_info(node_id, graph_storage)
        return node_cache[node_id]

    for i, (node_name, _) in enumerate(nodes):
        node_dict[node_name] = i

    if traverse_strategy.loss_strategy == "both":
        er_tuples = [
            ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
            for edge in edges
        ]
        edges = _sort_tuples(er_tuples, edge_sampling)
    elif traverse_strategy.loss_strategy == "only_edge":
        edges = _sort_edges(edges, edge_sampling)
    else:
        raise ValueError(f"Invalid loss strategy: {traverse_strategy.loss_strategy}")

    for i, (src, tgt, _) in enumerate(edges):
        edge_adj_list[src].append(i)
        edge_adj_list[tgt].append(i)

    for edge in tqdm_async(edges, desc="Preparing batches"):
        if edge[2].get("visited", False):
            continue

        edge[2]["visited"] = True

        _process_nodes = []
        _process_edges = []

        src_id = edge[0]
        tgt_id = edge[1]

        _process_nodes.extend(
            [await get_cached_node_info(src_id), await get_cached_node_info(tgt_id)]
        )
        _process_edges.append(edge)

        if expand_method == "max_width":
            level_n_edges = _get_level_n_edges_by_max_width(
                edge_adj_list,
                node_dict,
                edges,
                nodes,
                edge,
                max_depth,
                traverse_strategy.bidirectional,
                traverse_strategy.max_extra_edges,
                edge_sampling,
                traverse_strategy.loss_strategy,
            )
        else:
            level_n_edges = _get_level_n_edges_by_max_tokens(
                edge_adj_list,
                node_dict,
                edges,
                nodes,
                edge,
                max_depth,
                traverse_strategy.bidirectional,
                traverse_strategy.max_tokens,
                edge_sampling,
                traverse_strategy.loss_strategy,
            )

        for _edge in level_n_edges:
            _process_nodes.append(await get_cached_node_info(_edge[0]))
            _process_nodes.append(await get_cached_node_info(_edge[1]))
            _process_edges.append(_edge)

        # Deduplicate nodes by id and edges by (src, tgt) pair.
        _process_nodes = list({node["node_id"]: node for node in _process_nodes}.values())
        _process_edges = list({(edge[0], edge[1]): edge for edge in _process_edges}.values())

        processing_batches.append((_process_nodes, _process_edges))

    logger.info("Processing batches: %d", len(processing_batches))

    # Isolated nodes
    isolated_node_strategy = traverse_strategy.isolated_node_strategy
    if isolated_node_strategy == "add":
        processing_batches = await _add_isolated_nodes(nodes, processing_batches, graph_storage)
        logger.info(
            "Processing batches after adding isolated nodes: %d",
            len(processing_batches),
        )

    return processing_batches


async def _add_isolated_nodes(
    nodes: list,
    processing_batches: list,
    graph_storage: NetworkXStorage,
) -> list:
    visited_nodes = set()
    for _process_nodes, _ in processing_batches:
        for node in _process_nodes:
            visited_nodes.add(node["node_id"])
    # Any node that appears in no batch gets a single-node batch of its own.
    for node in nodes:
        if node[0] not in visited_nodes:
            _process_nodes = [await _get_node_info(node[0], graph_storage)]
            processing_batches.append((_process_nodes, []))
    return processing_batches
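# Minimal usage sketch (assumption: TraverseStrategy accepts keyword arguments
# matching the attributes read above; its actual constructor signature is not
# confirmed by this module, and storage setup is elided):
#
#   strategy = TraverseStrategy(
#       expand_method="max_width",      # or "max_tokens"
#       max_depth=2,
#       bidirectional=True,
#       max_extra_edges=5,
#       max_tokens=1024,
#       edge_sampling="min_loss",       # "random" | "min_loss" | "max_loss"
#       isolated_node_strategy="add",   # "add" appends batches for unvisited nodes
#       loss_strategy="only_edge",      # or "both"
#   )
#   batches = await get_batches_with_strategy(nodes, edges, graph_storage, strategy)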