"""Traversal-based batch construction utilities for knowledge-graph edges and nodes."""
import random
from collections import defaultdict

from tqdm.asyncio import tqdm as tqdm_async

from graphgen.utils import logger
from graphgen.models import NetworkXStorage, TraverseStrategy
async def _get_node_info(
    node_id: str,
    graph_storage: NetworkXStorage,
) -> dict:
    """
    Fetch a node's attribute dict from storage and tag it with its id.

    :param node_id: identifier of the node to look up
    :param graph_storage: graph storage instance exposing ``get_node``
    :return: node attributes merged with a "node_id" entry
    """
    attributes = await graph_storage.get_node(node_id)
    info = {"node_id": node_id}
    info.update(attributes)
    return info
def _get_level_n_edges_by_max_width( | |
edge_adj_list: dict, | |
node_dict: dict, | |
edges: list, | |
nodes, | |
src_edge: tuple, | |
max_depth: int, | |
bidirectional: bool, | |
max_extra_edges: int, | |
edge_sampling: str, | |
loss_strategy: str = "only_edge" | |
) -> list: | |
""" | |
Get level n edges for an edge. | |
n is decided by max_depth in traverse_strategy | |
:param edge_adj_list | |
:param node_dict | |
:param edges | |
:param nodes | |
:param src_edge | |
:param max_depth | |
:param bidirectional | |
:param max_extra_edges | |
:param edge_sampling | |
:return: level n edges | |
""" | |
src_id, tgt_id, _ = src_edge | |
level_n_edges = [] | |
start_nodes = {tgt_id} if not bidirectional else {src_id, tgt_id} | |
while max_depth > 0 and max_extra_edges > 0: | |
max_depth -= 1 | |
candidate_edges = [ | |
edges[edge_id] | |
for node in start_nodes | |
for edge_id in edge_adj_list[node] | |
if not edges[edge_id][2].get("visited", False) | |
] | |
if not candidate_edges: | |
break | |
if len(candidate_edges) >= max_extra_edges: | |
if loss_strategy == "both": | |
er_tuples = [([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge) for edge in candidate_edges] | |
candidate_edges = _sort_tuples(er_tuples, edge_sampling)[:max_extra_edges] | |
elif loss_strategy == "only_edge": | |
candidate_edges = _sort_edges(candidate_edges, edge_sampling)[:max_extra_edges] | |
else: | |
raise ValueError(f"Invalid loss strategy: {loss_strategy}") | |
for edge in candidate_edges: | |
level_n_edges.append(edge) | |
edge[2]["visited"] = True | |
break | |
max_extra_edges -= len(candidate_edges) | |
new_start_nodes = set() | |
for edge in candidate_edges: | |
level_n_edges.append(edge) | |
edge[2]["visited"] = True | |
if not edge[0] in start_nodes: | |
new_start_nodes.add(edge[0]) | |
if not edge[1] in start_nodes: | |
new_start_nodes.add(edge[1]) | |
start_nodes = new_start_nodes | |
return level_n_edges | |
def _get_level_n_edges_by_max_tokens(
    edge_adj_list: dict,
    node_dict: dict,
    edges: list,
    nodes: list,
    src_edge: tuple,
    max_depth: int,
    bidirectional: bool,
    max_tokens: int,
    edge_sampling: str,
    loss_strategy: str = "only_edge"
) -> list:
    """
    Collect unvisited edges within ``max_depth`` hops of ``src_edge`` until a
    token budget is exhausted (breadth-first expansion).

    n is decided by max_depth in traverse_strategy.
    :param edge_adj_list: node name -> list of incident edge indices into ``edges``
    :param node_dict: node name -> index into ``nodes``
    :param edges: list of (src, tgt, data) tuples; data dicts are mutated in place ("visited" flag)
    :param nodes: list of (name, data) tuples; data must carry a "length" token count
    :param src_edge: (src, tgt, data) edge the expansion starts from
    :param max_depth: maximum number of BFS levels to expand
    :param bidirectional: expand from both endpoints when True, otherwise only from the target
    :param max_tokens: token budget; the source edge and its endpoints are charged up front
    :param edge_sampling: "random" | "min_loss" | "max_loss" candidate ordering
    :param loss_strategy: "both" (node + edge loss) or "only_edge" (edge loss only)
    :return: level n edges, each marked visited
    :raises ValueError: on an unknown ``loss_strategy`` or ``edge_sampling``
    """
    src_id, tgt_id, src_edge_data = src_edge
    # Charge the seed edge and both of its endpoints against the budget first.
    max_tokens -= (src_edge_data["length"] + nodes[node_dict[src_id]][1]["length"]
                   + nodes[node_dict[tgt_id]][1]["length"])

    level_n_edges = []
    start_nodes = {tgt_id} if not bidirectional else {src_id, tgt_id}
    # Nodes whose token length has already been charged against the budget.
    temp_nodes = {src_id, tgt_id}

    while max_depth > 0 and max_tokens > 0:
        max_depth -= 1
        candidate_edges = [
            edges[edge_id]
            for node in start_nodes
            for edge_id in edge_adj_list[node]
            if not edges[edge_id][2].get("visited", False)
        ]

        if not candidate_edges:
            break

        if loss_strategy == "both":
            er_tuples = [
                ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
                for edge in candidate_edges
            ]
            candidate_edges = _sort_tuples(er_tuples, edge_sampling)
        elif loss_strategy == "only_edge":
            candidate_edges = _sort_edges(candidate_edges, edge_sampling)
        else:
            raise ValueError(f"Invalid loss strategy: {loss_strategy}")

        for edge in candidate_edges:
            max_tokens -= edge[2]["length"]
            # Each node's length is charged only the first time it enters the batch.
            if edge[0] not in temp_nodes:
                max_tokens -= nodes[node_dict[edge[0]]][1]["length"]
            if edge[1] not in temp_nodes:
                max_tokens -= nodes[node_dict[edge[1]]][1]["length"]

            if max_tokens < 0:
                # Budget exhausted: the current edge is NOT included.
                return level_n_edges

            level_n_edges.append(edge)
            edge[2]["visited"] = True
            temp_nodes.add(edge[0])
            temp_nodes.add(edge[1])

        new_start_nodes = set()
        for edge in candidate_edges:
            if edge[0] not in start_nodes:
                new_start_nodes.add(edge[0])
            if edge[1] not in start_nodes:
                new_start_nodes.add(edge[1])
        start_nodes = new_start_nodes

    return level_n_edges
def _sort_tuples(er_tuples: list, edge_sampling: str) -> list: | |
""" | |
Sort edges with edge sampling strategy | |
:param er_tuples: [(nodes:list, edge:tuple)] | |
:param edge_sampling: edge sampling strategy (random, min_loss, max_loss) | |
:return: sorted edges | |
""" | |
if edge_sampling == "random": | |
er_tuples = random.sample(er_tuples, len(er_tuples)) | |
elif edge_sampling == "min_loss": | |
er_tuples = sorted(er_tuples, key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"]) | |
elif edge_sampling == "max_loss": | |
er_tuples = sorted(er_tuples, key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"], | |
reverse=True) | |
else: | |
raise ValueError(f"Invalid edge sampling: {edge_sampling}") | |
edges = [edge for _, edge in er_tuples] | |
return edges | |
def _sort_edges(edges: list, edge_sampling: str) -> list: | |
""" | |
Sort edges with edge sampling strategy | |
:param edges: total edges | |
:param edge_sampling: edge sampling strategy (random, min_loss, max_loss) | |
:return: sorted edges | |
""" | |
if edge_sampling == "random": | |
random.shuffle(edges) | |
elif edge_sampling == "min_loss": | |
edges = sorted(edges, key=lambda x: x[2]["loss"]) | |
elif edge_sampling == "max_loss": | |
edges = sorted(edges, key=lambda x: x[2]["loss"], reverse=True) | |
else: | |
raise ValueError(f"Invalid edge sampling: {edge_sampling}") | |
return edges | |
async def get_batches_with_strategy(  # pylint: disable=too-many-branches
    nodes: list,
    edges: list,
    graph_storage: NetworkXStorage,
    traverse_strategy: TraverseStrategy
):
    """
    Split the graph into processing batches of (nodes, edges) according to the
    traversal strategy.

    Each unvisited edge seeds one batch: it is expanded level by level, bounded
    either by a maximum number of extra edges ("max_width") or by a token
    budget ("max_tokens"), then the batch is de-duplicated. Isolated nodes may
    be appended as single-node batches depending on the strategy.

    :param nodes: list of (node_name, node_data) tuples
    :param edges: list of (src, tgt, edge_data) tuples; edge_data is mutated
        in place ("visited" flag)
    :param graph_storage: storage backend used to fetch full node info
    :param traverse_strategy: expansion / sampling / isolated-node configuration
    :return: list of (process_nodes, process_edges) batches
    :raises ValueError: on an invalid expand method or loss strategy
    """
    expand_method = traverse_strategy.expand_method
    if expand_method == "max_width":
        logger.info("Using max width strategy")
    elif expand_method == "max_tokens":
        logger.info("Using max tokens strategy")
    else:
        raise ValueError(f"Invalid expand method: {expand_method}")

    max_depth = traverse_strategy.max_depth
    edge_sampling = traverse_strategy.edge_sampling

    # Build the edge adjacency list (node name -> incident edge indices).
    edge_adj_list = defaultdict(list)
    node_dict = {}
    processing_batches = []

    # Cache node lookups so each node is fetched from storage at most once.
    node_cache = {}

    async def get_cached_node_info(node_id: str) -> dict:
        if node_id not in node_cache:
            node_cache[node_id] = await _get_node_info(node_id, graph_storage)
        return node_cache[node_id]

    for i, (node_name, _) in enumerate(nodes):
        node_dict[node_name] = i

    # Pre-order the edges with the configured sampling strategy so the
    # traversal below visits them in priority order.
    if traverse_strategy.loss_strategy == "both":
        er_tuples = [
            ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
            for edge in edges
        ]
        edges = _sort_tuples(er_tuples, edge_sampling)
    elif traverse_strategy.loss_strategy == "only_edge":
        edges = _sort_edges(edges, edge_sampling)
    else:
        raise ValueError(f"Invalid loss strategy: {traverse_strategy.loss_strategy}")

    for i, (src, tgt, _) in enumerate(edges):
        edge_adj_list[src].append(i)
        edge_adj_list[tgt].append(i)

    for edge in tqdm_async(edges, desc="Preparing batches"):
        if edge[2].get("visited", False):
            continue
        edge[2]["visited"] = True

        _process_nodes = []
        _process_edges = []

        src_id = edge[0]
        tgt_id = edge[1]

        _process_nodes.extend([await get_cached_node_info(src_id),
                               await get_cached_node_info(tgt_id)])
        _process_edges.append(edge)

        if expand_method == "max_width":
            level_n_edges = _get_level_n_edges_by_max_width(
                edge_adj_list, node_dict, edges, nodes, edge, max_depth,
                traverse_strategy.bidirectional, traverse_strategy.max_extra_edges,
                edge_sampling, traverse_strategy.loss_strategy
            )
        else:
            level_n_edges = _get_level_n_edges_by_max_tokens(
                edge_adj_list, node_dict, edges, nodes, edge, max_depth,
                traverse_strategy.bidirectional, traverse_strategy.max_tokens,
                edge_sampling, traverse_strategy.loss_strategy
            )

        for _edge in level_n_edges:
            _process_nodes.append(await get_cached_node_info(_edge[0]))
            _process_nodes.append(await get_cached_node_info(_edge[1]))
            _process_edges.append(_edge)

        # De-duplicate nodes (by node_id) and edges (by endpoint pair).
        _process_nodes = list({node['node_id']: node for node in _process_nodes}.values())
        _process_edges = list({(edge[0], edge[1]): edge for edge in _process_edges}.values())

        processing_batches.append((_process_nodes, _process_edges))

    logger.info("Processing batches: %d", len(processing_batches))

    # Optionally append nodes that ended up in no batch as single-node batches.
    isolated_node_strategy = traverse_strategy.isolated_node_strategy
    if isolated_node_strategy == "add":
        processing_batches = await _add_isolated_nodes(nodes, processing_batches, graph_storage)
        logger.info("Processing batches after adding isolated nodes: %d", len(processing_batches))

    return processing_batches
async def _add_isolated_nodes(
    nodes: list,
    processing_batches: list,
    graph_storage: NetworkXStorage,
) -> list:
    """
    Append a single-node batch for every node that appears in no existing batch.

    :param nodes: list of (node_name, node_data) tuples
    :param processing_batches: list of (process_nodes, process_edges) batches; extended in place
    :param graph_storage: storage backend used to fetch full node info
    :return: the (mutated) processing_batches list
    """
    # Collect every node id already covered by some batch.
    covered = {
        batch_node["node_id"]
        for batch_nodes, _ in processing_batches
        for batch_node in batch_nodes
    }
    for node in nodes:
        if node[0] not in covered:
            info = await _get_node_info(node[0], graph_storage)
            processing_batches.append(([info], []))
    return processing_batches