Spaces:
Running
Running
import math | |
import asyncio | |
from tqdm.asyncio import tqdm as tqdm_async | |
from graphgen.models import NetworkXStorage, OpenAIModel, JsonKVStorage | |
from graphgen.utils import logger, yes_no_loss_entropy | |
from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT | |
async def judge_statement(  # pylint: disable=too-many-statements
    trainee_llm_client: OpenAIModel,
    graph_storage: NetworkXStorage,
    rephrase_storage: JsonKVStorage,
    re_judge: bool = False,
    max_concurrent: int = 1000,
) -> NetworkXStorage:
    """
    Judge every edge and node of the graph and store a loss on each of them.

    All edges are processed first, then all nodes. For each item, its
    ``description`` is used as the key into ``rephrase_storage``. Every
    ``(statement, ground_truth)`` pair stored under that key is formatted into
    ``STATEMENT_JUDGEMENT_PROMPT['TEMPLATE']`` and sent to the trainee model via
    ``generate_topk_per_token``; the top candidates of the first returned token
    are collected and passed, together with the ground truths, to
    ``yes_no_loss_entropy``. The result is written to the item's ``"loss"``
    field and persisted with ``update_edge`` / ``update_node``.

    If judging an item raises, the error is logged and the item receives the
    fallback loss ``-math.log(0.1)`` instead; the run continues.

    :param trainee_llm_client: judge the statements to get comprehension loss
    :param graph_storage: graph storage instance
    :param rephrase_storage: rephrase storage instance
    :param re_judge: when False, items whose ``"loss"`` is already set (not
        None) are skipped; when True, every item is judged again
    :param max_concurrent: max number of items judged at the same time
    :return: the ``graph_storage`` instance that was passed in
    """
    semaphore = asyncio.Semaphore(max_concurrent)

    async def _compute_loss(description_key):
        """Return the loss over all rephrased statements stored under the key."""
        rephrased = await rephrase_storage.get_by_id(description_key)
        # Explicit check instead of `assert`: asserts are stripped under -O and
        # would log an empty error message.
        if rephrased is None:
            raise ValueError(f"no rephrased statements found for {description_key!r}")
        # Unpack the ground truths up front so that malformed entries fail
        # before any request is sent to the model.
        ground_truths = [gt for _, gt in rephrased]
        judgements = []
        for statement, _ in rephrased:
            judgement = await trainee_llm_client.generate_topk_per_token(
                STATEMENT_JUDGEMENT_PROMPT['TEMPLATE'].format(statement=statement)
            )
            # Only the first generated token's candidates are used.
            judgements.append(judgement[0].top_candidates)
        return yes_no_loss_entropy(judgements, ground_truths)

    async def _judge_single_relation(
        edge: tuple,
    ):
        """Judge one ``(source_id, target_id, edge_data)`` edge and persist its loss."""
        async with semaphore:
            source_id = edge[0]
            target_id = edge[1]
            edge_data = edge[2]

            if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
                logger.info(
                    "Edge %s -> %s already judged, loss: %s, skip",
                    source_id, target_id, edge_data["loss"],
                )
                return source_id, target_id, edge_data

            description = edge_data["description"]

            try:
                loss = await _compute_loss(description)
                # `description` is the edge's own description here; it is no
                # longer overwritten by the rephrased statements.
                logger.info(
                    "Edge %s -> %s description: %s loss: %s",
                    source_id, target_id, description, loss,
                )
                edge_data["loss"] = loss
            except Exception as e:  # pylint: disable=broad-except
                logger.error("Error in judging relation %s -> %s: %s", source_id, target_id, e)
                # NOTE: the message says 0.1, but the value stored below is
                # -log(0.1) (about 2.30).
                logger.info("Use default loss 0.1")
                edge_data["loss"] = -math.log(0.1)

            await graph_storage.update_edge(source_id, target_id, edge_data)
            return source_id, target_id, edge_data

    edges = await graph_storage.get_all_edges()

    for result in tqdm_async(
        asyncio.as_completed([_judge_single_relation(edge) for edge in edges]),
        total=len(edges),
        desc="Judging relations"
    ):
        # Awaiting each result propagates errors raised outside the try block
        # (for example a missing "description" key or a failed storage update).
        await result

    async def _judge_single_entity(
        node: tuple,
    ):
        """Judge one ``(node_id, node_data)`` node and persist its loss."""
        async with semaphore:
            node_id = node[0]
            node_data = node[1]

            if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
                logger.info("Node %s already judged, loss: %s, skip", node_id, node_data["loss"])
                return node_id, node_data

            description = node_data["description"]

            try:
                loss = await _compute_loss(description)
                logger.info("Node %s description: %s loss: %s", node_id, description, loss)
                node_data["loss"] = loss
            except Exception as e:  # pylint: disable=broad-except
                logger.error("Error in judging entity %s: %s", node_id, e)
                # NOTE: the message says 0.1, but the value stored below is
                # -log(0.1) (about 2.30).
                logger.info("Use default loss 0.1")
                node_data["loss"] = -math.log(0.1)

            await graph_storage.update_node(node_id, node_data)
            return node_id, node_data

    nodes = await graph_storage.get_all_nodes()

    for result in tqdm_async(
        asyncio.as_completed([_judge_single_entity(node) for node in nodes]),
        total=len(nodes),
        desc="Judging entities"
    ):
        await result

    return graph_storage
async def skip_judge_statement(
    graph_storage: NetworkXStorage,
    max_concurrent: int = 1000
):
    """
    Give every unjudged edge and node the fallback loss without judging it.

    Edges are handled first, then nodes. An item whose ``"loss"`` is already
    set (not None) is only logged and left unchanged. Any other item gets
    ``-math.log(0.1)`` as its ``"loss"`` and is written back with
    ``update_edge`` / ``update_node``.

    :param graph_storage: graph storage instance
    :param max_concurrent: max number of items processed at the same time
    :return: the ``graph_storage`` instance that was passed in
    """
    limiter = asyncio.Semaphore(max_concurrent)

    async def _skip_single_relation(
        edge: tuple,
    ):
        """Assign the fallback loss to one edge unless it already has a loss."""
        async with limiter:
            src, dst, attrs = edge[0], edge[1], edge[2]
            if "loss" not in attrs or attrs["loss"] is None:
                attrs["loss"] = -math.log(0.1)
                await graph_storage.update_edge(src, dst, attrs)
            else:
                logger.info("Edge %s -> %s already judged, loss: %s, skip", src, dst, attrs["loss"])
            return src, dst, attrs

    all_edges = await graph_storage.get_all_edges()
    edge_jobs = [_skip_single_relation(item) for item in all_edges]
    for finished in tqdm_async(
        asyncio.as_completed(edge_jobs),
        total=len(all_edges),
        desc="Skipping judgement of relations"
    ):
        # Awaiting surfaces any exception raised inside the task.
        await finished

    async def _skip_single_entity(
        node: tuple,
    ):
        """Assign the fallback loss to one node unless it already has a loss."""
        async with limiter:
            ident, attrs = node[0], node[1]
            if "loss" not in attrs or attrs["loss"] is None:
                attrs["loss"] = -math.log(0.1)
                await graph_storage.update_node(ident, attrs)
            else:
                logger.info("Node %s already judged, loss: %s, skip", ident, attrs["loss"])
            return ident, attrs

    all_nodes = await graph_storage.get_all_nodes()
    node_jobs = [_skip_single_entity(item) for item in all_nodes]
    for finished in tqdm_async(
        asyncio.as_completed(node_jobs),
        total=len(all_nodes),
        desc="Skipping judgement of entities"
    ):
        await finished

    return graph_storage