File size: 3,908 Bytes
acd7cf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import asyncio
from collections import defaultdict

from tqdm.asyncio import tqdm as tqdm_async
from graphgen.models import JsonKVStorage, OpenAIModel, NetworkXStorage
from graphgen.utils import logger, detect_main_language
from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT


async def quiz(
        synth_llm_client: OpenAIModel,
        graph_storage: NetworkXStorage,
        rephrase_storage: JsonKVStorage,
        max_samples: int = 1,
        max_concurrent: int = 1000) -> JsonKVStorage:
    """
    Build labelled quiz samples for every edge and node description in the graph.

    Each distinct, non-empty description is keyed by itself and seeded with the
    original text as a ``'yes'`` sample. Then, per description:

    * ``max_samples - 1`` generations from the
      ``DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE']`` prompt are added
      with label ``'yes'``;
    * ``max_samples`` generations from the ``'ANTI_TEMPLATE'`` prompt are added
      with label ``'no'``.

    Descriptions that already have an entry in ``rephrase_storage`` are not sent
    to the LLM. Duplicate descriptions across edges and nodes are quizzed only
    once. Failed or empty generations are logged or dropped and do not abort
    the run.

    NOTE(review): already-stored descriptions are still upserted with only the
    original sample. Whether that overwrites the existing entry depends on
    ``JsonKVStorage.upsert``; confirm.

    :param synth_llm_client: LLM client used to generate the statements
    :param graph_storage: graph storage whose edges and nodes are quizzed
    :param rephrase_storage: storage receiving
        ``{description: [(text, 'yes' | 'no'), ...]}``
    :param max_samples: number of 'no' samples, and of 'yes' samples counting
        the original, per description
    :param max_concurrent: maximum number of concurrent LLM requests
    :return: ``rephrase_storage``, after all results have been upserted
    """

    semaphore = asyncio.Semaphore(max_concurrent)
    results = defaultdict(list)
    tasks = []

    async def _process_single_quiz(
        des: str,
        prompt: str,
        gt: str
    ):
        """Generate one statement for ``des`` labelled ``gt``; None if skipped or failed."""
        async with semaphore:
            try:
                # Skip the LLM call when this description already has an entry
                # in rephrase_storage.
                if await rephrase_storage.get_by_id(des):
                    return None

                new_description = await synth_llm_client.generate_answer(
                    prompt,
                    temperature=1
                )
                if not new_description:
                    # Do not store empty generations as quiz samples.
                    return None
                return {des: [(new_description, gt)]}

            except Exception as e:  # pylint: disable=broad-except
                # Best-effort: one failed generation must not abort the whole quiz.
                logger.error("Error when quizzing description %s: %s", des, e)
                return None

    def _schedule(description: str) -> None:
        """Seed ``results`` with the original description and queue its LLM tasks."""
        language = "English" if detect_main_language(description) == "en" else "Chinese"
        prompts = DESCRIPTION_REPHRASING_PROMPT[language]

        results[description] = [(description, 'yes')]

        for i in range(max_samples):
            # The original description already serves as the first 'yes' sample.
            if i > 0:
                tasks.append(_process_single_quiz(
                    description,
                    prompts['TEMPLATE'].format(input_sentence=description),
                    'yes'))
            tasks.append(_process_single_quiz(
                description,
                prompts['ANTI_TEMPLATE'].format(input_sentence=description),
                'no'))

    edges = await graph_storage.get_all_edges()
    nodes = await graph_storage.get_all_nodes()

    # Edge data is read from index 2 and node data from index 1. Missing
    # descriptions are skipped instead of raising KeyError.
    descriptions = [edge[2].get("description") for edge in edges]
    descriptions += [node[1].get("description") for node in nodes]

    # dict.fromkeys drops duplicates while keeping edges-then-nodes order, so a
    # description shared by several graph elements is only quizzed once.
    for description in dict.fromkeys(descriptions):
        if description:
            _schedule(description)

    for future in tqdm_async(
            asyncio.as_completed(tasks),
            total=len(tasks),
            desc="Quizzing descriptions"
    ):
        new_result = await future
        if new_result:
            for key, value in new_result.items():
                results[key].extend(value)

    for key, value in results.items():
        # Order-preserving de-duplication. set() made the stored order depend
        # on hash randomisation and vary between runs.
        await rephrase_storage.upsert({key: list(dict.fromkeys(value))})

    return rephrase_storage