Spaces:
Running
Running
File size: 3,908 Bytes
acd7cf4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import asyncio
from collections import defaultdict
from tqdm.asyncio import tqdm as tqdm_async
from graphgen.models import JsonKVStorage, OpenAIModel, NetworkXStorage
from graphgen.utils import logger, detect_main_language
from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT
async def quiz(
        synth_llm_client: OpenAIModel,
        graph_storage: NetworkXStorage,
        rephrase_storage: JsonKVStorage,
        max_samples: int = 1,
        max_concurrent: int = 1000) -> JsonKVStorage:
    """
    Build quiz samples for the description of every edge and node in the graph.

    Each description is first recorded as ``(description, 'yes')``. Then, for
    every round in ``range(max_samples)``, one generation using the
    ``ANTI_TEMPLATE`` prompt (labelled ``'no'``) is scheduled, and for every
    round after the first also one using the ``TEMPLATE`` prompt (labelled
    ``'yes'``). Generation is skipped for descriptions that already have an
    entry in ``rephrase_storage``. Failed generations are logged and dropped.

    :param synth_llm_client: LLM client used to generate the rephrased statements
    :param graph_storage: graph storage instance providing the edges and nodes
    :param rephrase_storage: storage the ``(statement, label)`` samples are upserted into
    :param max_samples: number of sampling rounds per description
    :param max_concurrent: maximum number of generations running at the same time
    :return: the ``rephrase_storage`` instance that was passed in
    """
    semaphore = asyncio.Semaphore(max_concurrent)

    async def _process_single_quiz(
            des: str,
            prompt: str,
            gt: str
    ):
        """
        Generate one rephrasing of ``des`` and pair it with the label ``gt``.

        Returns ``{des: [(generated_text, gt)]}``, or ``None`` when ``des`` is
        already present in ``rephrase_storage`` or the generation raised.
        """
        async with semaphore:
            try:
                # Already present in rephrase_storage: nothing new to generate.
                descriptions = await rephrase_storage.get_by_id(des)
                if descriptions:
                    return None

                new_description = await synth_llm_client.generate_answer(
                    prompt,
                    temperature=1
                )
                return {des: [(new_description, gt)]}

            except Exception as e:  # pylint: disable=broad-except
                # Best effort: one failed generation must not abort the whole run.
                logger.error("Error when quizzing description %s: %s", des, e)
                return None

    results = defaultdict(list)
    tasks = []

    def _schedule_quizzes(description: str) -> None:
        """Seed ``results`` for ``description`` and queue its generation tasks."""
        language = "English" if detect_main_language(description) == "en" else "Chinese"
        templates = DESCRIPTION_REPHRASING_PROMPT[language]

        # The original description is always kept as a 'yes' sample.
        results[description] = [(description, 'yes')]

        for i in range(max_samples):
            # The first round only asks for the 'no' variant, because the
            # original description already serves as that round's 'yes' sample.
            if i > 0:
                tasks.append(
                    _process_single_quiz(
                        description,
                        templates['TEMPLATE'].format(input_sentence=description),
                        'yes')
                )
            tasks.append(
                _process_single_quiz(
                    description,
                    templates['ANTI_TEMPLATE'].format(input_sentence=description),
                    'no')
            )

    edges = await graph_storage.get_all_edges()
    nodes = await graph_storage.get_all_nodes()

    # Edges are indexed as edge[2] for their attribute dict, nodes as node[1].
    for edge in edges:
        _schedule_quizzes(edge[2]["description"])
    for node in nodes:
        _schedule_quizzes(node[1]["description"])

    for result in tqdm_async(
            asyncio.as_completed(tasks),
            total=len(tasks),
            desc="Quizzing descriptions"
    ):
        new_result = await result
        if new_result:
            for key, value in new_result.items():
                results[key].extend(value)

    for key, samples in results.items():
        # Drop duplicate samples while keeping first-seen order, so the stored
        # list is deterministic and the original description stays first.
        unique_samples = list(dict.fromkeys(samples))
        await rephrase_storage.upsert({key: unique_samples})

    return rephrase_storage
|