File size: 5,345 Bytes
acd7cf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import re
import asyncio
from typing import List
from collections import defaultdict

import gradio as gr
from tqdm.asyncio import tqdm as tqdm_async
from graphgen.models import Chunk, OpenAIModel, Tokenizer
from graphgen.models.storage.base_storage import BaseGraphStorage
from graphgen.templates import KG_EXTRACTION_PROMPT
from graphgen.utils import (logger, pack_history_conversations, split_string_by_multi_markers,
                            handle_single_entity_extraction, handle_single_relationship_extraction,
                            detect_if_chinese)
from graphgen.operators.merge_kg import merge_nodes, merge_edges


# pylint: disable=too-many-statements
async def extract_kg(
        llm_client: OpenAIModel,
        kg_instance: BaseGraphStorage,
        tokenizer_instance: Tokenizer,
        chunks: List[Chunk],
        progress_bar: gr.Progress = None,
        max_concurrent: int = 1000
):
    """
    Extract entities and relationships from text chunks and merge them into a knowledge graph.

    :param llm_client: Synthesizer LLM model to extract entities and relationships
    :param kg_instance: graph storage that the merged nodes and edges are written into
    :param tokenizer_instance: tokenizer forwarded to the node/edge merge helpers
    :param chunks: text chunks to run extraction on
    :param progress_bar: Gradio progress bar to show the progress of the extraction
    :param max_concurrent: maximum number of chunks processed concurrently
    :return: ``kg_instance`` with the newly extracted nodes and edges merged in
    """

    # Guard against empty input: the progress computation below divides by
    # len(chunks), which would raise ZeroDivisionError.
    if not chunks:
        logger.warning("No chunks provided for KG extraction; graph left unchanged")
        return kg_instance

    semaphore = asyncio.Semaphore(max_concurrent)

    async def _process_single_content(chunk: Chunk, max_loop: int = 3):
        # Extract (nodes, edges) from one chunk, gleaning up to max_loop extra
        # passes when the model reports more entities remain.
        async with semaphore:
            chunk_id = chunk.id
            content = chunk.content
            language = "Chinese" if detect_if_chinese(content) else "English"
            # Build a per-call format dict instead of mutating the shared
            # KG_EXTRACTION_PROMPT["FORMAT"]: concurrent chunks in different
            # languages would otherwise race on the global template state.
            prompt_format = {**KG_EXTRACTION_PROMPT["FORMAT"], "language": language}

            hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
                **prompt_format, input_text=content
            )

            final_result = await llm_client.generate_answer(hint_prompt)
            logger.info('First result: %s', final_result)

            history = pack_history_conversations(hint_prompt, final_result)
            for loop_index in range(max_loop):
                # Ask the model whether more entities remain to be gleaned.
                if_loop_result = await llm_client.generate_answer(
                    text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"],
                    history=history
                )
                if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
                if if_loop_result != "yes":
                    break

                glean_result = await llm_client.generate_answer(
                    text=KG_EXTRACTION_PROMPT[language]["CONTINUE"],
                    history=history
                )
                logger.info('Loop %s glean: %s', loop_index, glean_result)

                history += pack_history_conversations(
                    KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
                )
                final_result += glean_result

            records = split_string_by_multi_markers(
                final_result,
                [
                    KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
                    KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
                ],
            )

            nodes = defaultdict(list)
            edges = defaultdict(list)

            for record in records:
                match = re.search(r"\((.*)\)", record)
                if match is None:
                    continue
                # Extract the content inside the parentheses.
                record_body = match.group(1)
                record_attributes = split_string_by_multi_markers(
                    record_body, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
                )

                # A record parses as either an entity or a relationship, never both.
                entity = await handle_single_entity_extraction(record_attributes, chunk_id)
                if entity is not None:
                    nodes[entity["entity_name"]].append(entity)
                    continue
                relation = await handle_single_relationship_extraction(record_attributes, chunk_id)
                if relation is not None:
                    edges[(relation["src_id"], relation["tgt_id"])].append(relation)
            return dict(nodes), dict(edges)

    results = []
    chunk_number = len(chunks)
    desc = "[3/4]Extracting entities and relationships from chunks"
    async for result in tqdm_async(
        asyncio.as_completed([_process_single_content(c) for c in chunks]),
        total=chunk_number,
        desc=desc,
        unit="chunk",
    ):
        try:
            if progress_bar is not None:
                progress_bar(len(results) / chunk_number, desc=desc)
            results.append(await result)
            if progress_bar is not None and len(results) == chunk_number:
                progress_bar(1, desc=desc)
        except Exception as e:  # pylint: disable=broad-except
            # Best-effort: a single failed chunk is logged and skipped rather
            # than aborting the whole extraction run.
            logger.error("Error occurred while extracting entities and relationships from chunks: %s", e)

    nodes = defaultdict(list)
    edges = defaultdict(list)
    for n, e in results:
        for k, v in n.items():
            nodes[k].extend(v)
        for k, v in e.items():
            # Treat edges as undirected: canonicalise the key so that
            # (a, b) and (b, a) accumulate into the same bucket.
            edges[tuple(sorted(k))].extend(v)

    await merge_nodes(nodes, kg_instance, llm_client, tokenizer_instance)
    await merge_edges(edges, kg_instance, llm_client, tokenizer_instance)

    return kg_instance