File size: 6,802 Bytes
acd7cf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import math
import asyncio
from tqdm.asyncio import tqdm as tqdm_async
from graphgen.models import NetworkXStorage, OpenAIModel, JsonKVStorage
from graphgen.utils import logger, yes_no_loss_entropy
from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT


async def judge_statement( # pylint: disable=too-many-statements
        trainee_llm_client: OpenAIModel,
        graph_storage: NetworkXStorage,
        rephrase_storage: JsonKVStorage,
        re_judge: bool = False,
        max_concurrent: int = 1000) -> NetworkXStorage:
    """
    Judge every edge and node of the graph and store a loss on each of them.

    For each edge/node, its ``description`` field is used as the id to look up
    a list of ``(statement, gt)`` pairs in ``rephrase_storage``. Every statement
    is sent to the trainee model through ``generate_topk_per_token`` and the
    collected top candidates, together with the ``gt`` values, are passed to
    ``yes_no_loss_entropy``. The result is written to the item's ``loss`` field
    and persisted with ``update_edge`` / ``update_node``.

    If anything fails while scoring an item, the error is logged and the item
    receives the fallback loss ``-math.log(0.1)`` instead.

    :param trainee_llm_client: judge the statements to get comprehension loss
    :param graph_storage: graph storage instance
    :param rephrase_storage: rephrase storage instance
    :param re_judge: re-judge items even if they already carry a non-None loss
    :param max_concurrent: max number of items judged at the same time
    :return: the ``graph_storage`` instance that was passed in
    """

    semaphore = asyncio.Semaphore(max_concurrent)
    # Loss assigned to an item whose judgement raised.
    fallback_loss = -math.log(0.1)

    def _already_judged(data) -> bool:
        # An item is skipped only when re-judging is off and a loss is present.
        return (not re_judge) and "loss" in data and data["loss"] is not None

    async def _compute_loss(description_id):
        # Shared by edges and nodes: fetch rephrased statements and score them.
        rephrased = await rephrase_storage.get_by_id(description_id)
        if rephrased is None:
            # Explicit raise instead of `assert`, which is stripped under -O.
            raise ValueError(f"no rephrased statements found for {description_id!r}")

        gts = [gt for _, gt in rephrased]
        judgements = []
        # Use a distinct loop name so the caller's description is not clobbered.
        for statement, _ in rephrased:
            judgement = await trainee_llm_client.generate_topk_per_token(
                STATEMENT_JUDGEMENT_PROMPT['TEMPLATE'].format(statement=statement)
            )
            judgements.append(judgement[0].top_candidates)

        return yes_no_loss_entropy(judgements, gts)

    async def _judge_single_relation(
        edge: tuple,
    ):
        async with semaphore:
            source_id = edge[0]
            target_id = edge[1]
            edge_data = edge[2]

            if _already_judged(edge_data):
                logger.info("Edge %s -> %s already judged, loss: %s, skip", source_id, target_id, edge_data["loss"])
                return source_id, target_id, edge_data

            description = edge_data["description"]

            try:
                loss = await _compute_loss(description)

                logger.info("Edge %s -> %s description: %s loss: %s", source_id, target_id, description, loss)

                edge_data["loss"] = loss
            except Exception as e: # pylint: disable=broad-except
                # Best effort: one failing item must not abort the whole run.
                logger.error("Error in judging relation %s -> %s: %s", source_id, target_id, e)
                logger.info("Use default loss 0.1")
                edge_data["loss"] = fallback_loss

            await graph_storage.update_edge(source_id, target_id, edge_data)
            return source_id, target_id, edge_data

    edges = await graph_storage.get_all_edges()

    # Results are persisted inside the workers; here we only drive them to
    # completion while tqdm reports progress.
    for pending in tqdm_async(
            asyncio.as_completed([_judge_single_relation(edge) for edge in edges]),
            total=len(edges),
            desc="Judging relations"
    ):
        await pending

    async def _judge_single_entity(
        node: tuple,
    ):
        async with semaphore:
            node_id = node[0]
            node_data = node[1]

            if _already_judged(node_data):
                logger.info("Node %s already judged, loss: %s, skip", node_id, node_data["loss"])
                return node_id, node_data

            description = node_data["description"]

            try:
                loss = await _compute_loss(description)

                logger.info("Node %s description: %s loss: %s", node_id, description, loss)

                node_data["loss"] = loss
            except Exception as e: # pylint: disable=broad-except
                # Best effort: one failing item must not abort the whole run.
                logger.error("Error in judging entity %s: %s", node_id, e)
                logger.info("Use default loss 0.1")
                node_data["loss"] = fallback_loss

            await graph_storage.update_node(node_id, node_data)
            return node_id, node_data

    nodes = await graph_storage.get_all_nodes()

    for pending in tqdm_async(
            asyncio.as_completed([_judge_single_entity(node) for node in nodes]),
            total=len(nodes),
            desc="Judging entities"
    ):
        await pending

    return graph_storage

async def skip_judge_statement(
        graph_storage: NetworkXStorage,
        max_concurrent: int = 1000
):
    """
    Give every edge and node a default loss without running any judgement.

    Items whose ``loss`` field is already set to a non-None value are only
    logged and left untouched. All other items get ``-math.log(0.1)`` written
    to ``loss`` and are persisted with ``update_edge`` / ``update_node``.

    :param graph_storage: graph storage instance
    :param max_concurrent: max number of items processed at the same time
    :return: the ``graph_storage`` instance that was passed in
    """
    limiter = asyncio.Semaphore(max_concurrent)

    async def _skip_single_relation(
        edge: tuple,
    ):
        async with limiter:
            src, dst, attrs = edge[0], edge[1], edge[2]

            if "loss" not in attrs or attrs["loss"] is None:
                # No usable loss yet: store the default and persist it.
                attrs["loss"] = -math.log(0.1)
                await graph_storage.update_edge(src, dst, attrs)
            else:
                logger.info("Edge %s -> %s already judged, loss: %s, skip", src, dst, attrs["loss"])

            return src, dst, attrs

    edges = await graph_storage.get_all_edges()
    edge_jobs = [_skip_single_relation(item) for item in edges]
    # Drive all workers to completion; tqdm only reports progress.
    for finished in tqdm_async(
            asyncio.as_completed(edge_jobs),
            total=len(edges),
            desc="Skipping judgement of relations"
    ):
        await finished

    async def _skip_single_entity(
        node: tuple,
    ):
        async with limiter:
            ident, attrs = node[0], node[1]

            if "loss" not in attrs or attrs["loss"] is None:
                # No usable loss yet: store the default and persist it.
                attrs["loss"] = -math.log(0.1)
                await graph_storage.update_node(ident, attrs)
            else:
                logger.info("Node %s already judged, loss: %s, skip", ident, attrs["loss"])

            return ident, attrs

    nodes = await graph_storage.get_all_nodes()
    node_jobs = [_skip_single_entity(item) for item in nodes]
    for finished in tqdm_async(
            asyncio.as_completed(node_jobs),
            total=len(nodes),
            desc="Skipping judgement of entities"
    ):
        await finished

    return graph_storage