# graphgen/operators/split_graph.py
import random
from collections import defaultdict
from tqdm.asyncio import tqdm as tqdm_async
from graphgen.utils import logger
from graphgen.models import NetworkXStorage, TraverseStrategy
async def _get_node_info(
node_id: str,
graph_storage: NetworkXStorage,
) -> dict:
"""
Get node info
:param node_id: node id
:param graph_storage: graph storage instance
:return: node info
"""
node_data = await graph_storage.get_node(node_id)
return {
"node_id": node_id,
**node_data
}
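
# Illustrative shape only (attribute names assumed from their use below, where
# node and edge attrs carry "length" and "loss"):
#   await _get_node_info("entity_a", storage)
#   -> {"node_id": "entity_a", "length": 12, "loss": 0.37, ...}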
def _get_level_n_edges_by_max_width(
edge_adj_list: dict,
node_dict: dict,
edges: list,
    nodes: list,
src_edge: tuple,
max_depth: int,
bidirectional: bool,
max_extra_edges: int,
edge_sampling: str,
loss_strategy: str = "only_edge"
) -> list:
"""
Get level n edges for an edge.
n is decided by max_depth in traverse_strategy
:param edge_adj_list
:param node_dict
:param edges
:param nodes
:param src_edge
:param max_depth
:param bidirectional
:param max_extra_edges
:param edge_sampling
:return: level n edges
"""
src_id, tgt_id, _ = src_edge
level_n_edges = []
    start_nodes = {src_id, tgt_id} if bidirectional else {tgt_id}
while max_depth > 0 and max_extra_edges > 0:
max_depth -= 1
candidate_edges = [
edges[edge_id]
for node in start_nodes
for edge_id in edge_adj_list[node]
if not edges[edge_id][2].get("visited", False)
]
if not candidate_edges:
break
if len(candidate_edges) >= max_extra_edges:
            if loss_strategy == "both":
                er_tuples = [
                    ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
                    for edge in candidate_edges
                ]
                candidate_edges = _sort_tuples(er_tuples, edge_sampling)[:max_extra_edges]
elif loss_strategy == "only_edge":
candidate_edges = _sort_edges(candidate_edges, edge_sampling)[:max_extra_edges]
else:
raise ValueError(f"Invalid loss strategy: {loss_strategy}")
for edge in candidate_edges:
level_n_edges.append(edge)
edge[2]["visited"] = True
break
max_extra_edges -= len(candidate_edges)
new_start_nodes = set()
for edge in candidate_edges:
level_n_edges.append(edge)
edge[2]["visited"] = True
            if edge[0] not in start_nodes:
                new_start_nodes.add(edge[0])
            if edge[1] not in start_nodes:
                new_start_nodes.add(edge[1])
start_nodes = new_start_nodes
return level_n_edges
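
# Expansion sketch (illustrative node names, not from the codebase): starting
# from edge ("a", "b"), depth 1 pulls the unvisited edges incident to "b" (or
# to both "a" and "b" when bidirectional); depth 2 expands from the endpoints
# reached at depth 1, and so on until max_depth or max_extra_edges is exhausted.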
def _get_level_n_edges_by_max_tokens(
edge_adj_list: dict,
node_dict: dict,
edges: list,
nodes: list,
src_edge: tuple,
max_depth: int,
bidirectional: bool,
max_tokens: int,
edge_sampling: str,
loss_strategy: str = "only_edge"
) -> list:
"""
Get level n edges for an edge.
n is decided by max_depth in traverse_strategy.
:param edge_adj_list
:param node_dict
:param edges
:param nodes
:param src_edge
:param max_depth
:param bidirectional
:param max_tokens
:param edge_sampling
:return: level n edges
"""
src_id, tgt_id, src_edge_data = src_edge
    # Charge the source edge and both endpoint nodes against the budget up front
    max_tokens -= (
        src_edge_data["length"]
        + nodes[node_dict[src_id]][1]["length"]
        + nodes[node_dict[tgt_id]][1]["length"]
    )
level_n_edges = []
    start_nodes = {src_id, tgt_id} if bidirectional else {tgt_id}
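    # temp_nodes tracks endpoints whose token length has already been charged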
temp_nodes = {src_id, tgt_id}
while max_depth > 0 and max_tokens > 0:
max_depth -= 1
candidate_edges = [
edges[edge_id]
for node in start_nodes
for edge_id in edge_adj_list[node]
if not edges[edge_id][2].get("visited", False)
]
if not candidate_edges:
break
        if loss_strategy == "both":
            er_tuples = [
                ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
                for edge in candidate_edges
            ]
            candidate_edges = _sort_tuples(er_tuples, edge_sampling)
elif loss_strategy == "only_edge":
candidate_edges = _sort_edges(candidate_edges, edge_sampling)
else:
raise ValueError(f"Invalid loss strategy: {loss_strategy}")
for edge in candidate_edges:
max_tokens -= edge[2]["length"]
            if edge[0] not in temp_nodes:
                max_tokens -= nodes[node_dict[edge[0]]][1]["length"]
            if edge[1] not in temp_nodes:
                max_tokens -= nodes[node_dict[edge[1]]][1]["length"]
if max_tokens < 0:
return level_n_edges
level_n_edges.append(edge)
edge[2]["visited"] = True
temp_nodes.add(edge[0])
temp_nodes.add(edge[1])
new_start_nodes = set()
for edge in candidate_edges:
            if edge[0] not in start_nodes:
                new_start_nodes.add(edge[0])
            if edge[1] not in start_nodes:
                new_start_nodes.add(edge[1])
start_nodes = new_start_nodes
return level_n_edges
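
# Budget arithmetic (illustrative numbers): with max_tokens=100, a source edge
# of length 20 and endpoint nodes of length 15 each, 100 - 20 - 15 - 15 = 50
# tokens remain; each accepted neighbour edge then subtracts its own length
# plus the length of any endpoint not yet counted in temp_nodes.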
def _sort_tuples(er_tuples: list, edge_sampling: str) -> list:
"""
Sort edges with edge sampling strategy
:param er_tuples: [(nodes:list, edge:tuple)]
:param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
:return: sorted edges
"""
if edge_sampling == "random":
er_tuples = random.sample(er_tuples, len(er_tuples))
    elif edge_sampling == "min_loss":
        er_tuples = sorted(
            er_tuples,
            key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
        )
    elif edge_sampling == "max_loss":
        er_tuples = sorted(
            er_tuples,
            key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
            reverse=True,
        )
else:
raise ValueError(f"Invalid edge sampling: {edge_sampling}")
edges = [edge for _, edge in er_tuples]
return edges
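
# Illustrative example (made-up losses): given two candidates whose combined
# node+edge losses are 0.9 and 0.4, edge_sampling="min_loss" returns the 0.4
# edge first, "max_loss" returns the 0.9 edge first, and "random" shuffles.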
def _sort_edges(edges: list, edge_sampling: str) -> list:
"""
Sort edges with edge sampling strategy
:param edges: total edges
:param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
:return: sorted edges
"""
if edge_sampling == "random":
random.shuffle(edges)
elif edge_sampling == "min_loss":
edges = sorted(edges, key=lambda x: x[2]["loss"])
elif edge_sampling == "max_loss":
edges = sorted(edges, key=lambda x: x[2]["loss"], reverse=True)
else:
raise ValueError(f"Invalid edge sampling: {edge_sampling}")
return edges
async def get_batches_with_strategy(  # pylint: disable=too-many-branches
    nodes: list,
    edges: list,
    graph_storage: NetworkXStorage,
    traverse_strategy: TraverseStrategy,
) -> list:
    """
    Split the graph into processing batches of (nodes, edges) according to
    the traverse strategy.

    :param nodes: list of (name, attrs) node tuples
    :param edges: list of (src, tgt, attrs) edge tuples
    :param graph_storage: graph storage instance
    :param traverse_strategy: traversal configuration
    :return: list of (process_nodes, process_edges) batches
    """
expand_method = traverse_strategy.expand_method
if expand_method == "max_width":
logger.info("Using max width strategy")
elif expand_method == "max_tokens":
logger.info("Using max tokens strategy")
else:
raise ValueError(f"Invalid expand method: {expand_method}")
max_depth = traverse_strategy.max_depth
edge_sampling = traverse_strategy.edge_sampling
    # Build the adjacency list mapping each node to its incident edge indices
edge_adj_list = defaultdict(list)
node_dict = {}
processing_batches = []
node_cache = {}
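    # Cache node info so repeated endpoints hit graph storage only once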
async def get_cached_node_info(node_id: str) -> dict:
if node_id not in node_cache:
node_cache[node_id] = await _get_node_info(node_id, graph_storage)
return node_cache[node_id]
for i, (node_name, _) in enumerate(nodes):
node_dict[node_name] = i
    if traverse_strategy.loss_strategy == "both":
        er_tuples = [
            ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
            for edge in edges
        ]
        edges = _sort_tuples(er_tuples, edge_sampling)
elif traverse_strategy.loss_strategy == "only_edge":
edges = _sort_edges(edges, edge_sampling)
else:
raise ValueError(f"Invalid loss strategy: {traverse_strategy.loss_strategy}")
for i, (src, tgt, _) in enumerate(edges):
edge_adj_list[src].append(i)
edge_adj_list[tgt].append(i)
for edge in tqdm_async(edges, desc="Preparing batches"):
if "visited" in edge[2] and edge[2]["visited"]:
continue
edge[2]["visited"] = True
_process_nodes = []
_process_edges = []
src_id = edge[0]
tgt_id = edge[1]
_process_nodes.extend([await get_cached_node_info(src_id),
await get_cached_node_info(tgt_id)])
_process_edges.append(edge)
if expand_method == "max_width":
level_n_edges = _get_level_n_edges_by_max_width(
edge_adj_list, node_dict, edges, nodes, edge, max_depth,
traverse_strategy.bidirectional, traverse_strategy.max_extra_edges,
edge_sampling, traverse_strategy.loss_strategy
)
else:
level_n_edges = _get_level_n_edges_by_max_tokens(
edge_adj_list, node_dict, edges, nodes, edge, max_depth,
traverse_strategy.bidirectional, traverse_strategy.max_tokens,
edge_sampling, traverse_strategy.loss_strategy
)
for _edge in level_n_edges:
_process_nodes.append(await get_cached_node_info(_edge[0]))
_process_nodes.append(await get_cached_node_info(_edge[1]))
_process_edges.append(_edge)
        # Deduplicate nodes by node_id and edges by their endpoint pair
        _process_nodes = list({node["node_id"]: node for node in _process_nodes}.values())
        _process_edges = list({(edge[0], edge[1]): edge for edge in _process_edges}.values())
processing_batches.append((_process_nodes, _process_edges))
logger.info("Processing batches: %d", len(processing_batches))
    # Handle isolated nodes
isolated_node_strategy = traverse_strategy.isolated_node_strategy
if isolated_node_strategy == "add":
processing_batches = await _add_isolated_nodes(nodes, processing_batches, graph_storage)
logger.info("Processing batches after adding isolated nodes: %d", len(processing_batches))
return processing_batches
async def _add_isolated_nodes(
    nodes: list,
    processing_batches: list,
    graph_storage: NetworkXStorage,
) -> list:
    """
    Append a single-node batch for every node not covered by any existing batch.
    """
visited_nodes = set()
for _process_nodes, _process_edges in processing_batches:
for node in _process_nodes:
visited_nodes.add(node["node_id"])
for node in nodes:
if node[0] not in visited_nodes:
_process_nodes = [await _get_node_info(node[0], graph_storage)]
processing_batches.append((_process_nodes, []))
return processing_batches
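
# Usage sketch (hypothetical, not part of this module; NetworkXStorage's
# accessors and TraverseStrategy's construction are assumed here):
#
#   strategy = TraverseStrategy(...)  # e.g. expand_method="max_width", max_depth=2
#   batches = await get_batches_with_strategy(nodes, edges, graph_storage, strategy)
#
# where nodes are (name, attrs) tuples and edges are (src, tgt, attrs) tuples,
# with attrs carrying the "length" and "loss" fields used above. Each returned
# batch is a (process_nodes, process_edges) pair for downstream generation.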