Spaces:
Running
Running
import asyncio | |
from collections import defaultdict | |
from tqdm.asyncio import tqdm as tqdm_async | |
from graphgen.models import JsonKVStorage, OpenAIModel, NetworkXStorage | |
from graphgen.utils import logger, detect_main_language | |
from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT | |
async def quiz( | |
synth_llm_client: OpenAIModel, | |
graph_storage: NetworkXStorage, | |
rephrase_storage: JsonKVStorage, | |
max_samples: int = 1, | |
max_concurrent: int = 1000) -> JsonKVStorage: | |
""" | |
Get all edges and quiz them | |
:param synth_llm_client: generate statements | |
:param graph_storage: graph storage instance | |
:param rephrase_storage: rephrase storage instance | |
:param max_samples: max samples for each edge | |
:param max_concurrent: max concurrent | |
:return: | |
""" | |
semaphore = asyncio.Semaphore(max_concurrent) | |
async def _process_single_quiz( | |
des: str, | |
prompt: str, | |
gt: str | |
): | |
async with semaphore: | |
try: | |
# 如果在rephrase_storage中已经存在,直接取出 | |
descriptions = await rephrase_storage.get_by_id(des) | |
if descriptions: | |
return None | |
new_description = await synth_llm_client.generate_answer( | |
prompt, | |
temperature=1 | |
) | |
return {des: [(new_description, gt)]} | |
except Exception as e: # pylint: disable=broad-except | |
logger.error("Error when quizzing description %s: %s", des, e) | |
return None | |
edges = await graph_storage.get_all_edges() | |
nodes = await graph_storage.get_all_nodes() | |
results = defaultdict(list) | |
tasks = [] | |
for edge in edges: | |
edge_data = edge[2] | |
description = edge_data["description"] | |
language = "English" if detect_main_language(description) == "en" else "Chinese" | |
results[description] = [(description, 'yes')] | |
for i in range(max_samples): | |
if i > 0: | |
tasks.append( | |
_process_single_quiz(description, | |
DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format( | |
input_sentence=description), 'yes') | |
) | |
tasks.append(_process_single_quiz(description, | |
DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format( | |
input_sentence=description), 'no')) | |
for node in nodes: | |
node_data = node[1] | |
description = node_data["description"] | |
language = "English" if detect_main_language(description) == "en" else "Chinese" | |
results[description] = [(description, 'yes')] | |
for i in range(max_samples): | |
if i > 0: | |
tasks.append( | |
_process_single_quiz(description, | |
DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format( | |
input_sentence=description), 'yes') | |
) | |
tasks.append(_process_single_quiz(description, | |
DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format( | |
input_sentence=description), 'no')) | |
for result in tqdm_async( | |
asyncio.as_completed(tasks), | |
total=len(tasks), | |
desc="Quizzing descriptions" | |
): | |
new_result = await result | |
if new_result: | |
for key, value in new_result.items(): | |
results[key].extend(value) | |
for key, value in results.items(): | |
results[key] = list(set(value)) | |
await rephrase_storage.upsert({key: results[key]}) | |
return rephrase_storage | |