import re
import asyncio
from typing import List
from collections import defaultdict

import gradio as gr
from tqdm.asyncio import tqdm as tqdm_async

from graphgen.models import Chunk, OpenAIModel, Tokenizer
from graphgen.models.storage.base_storage import BaseGraphStorage
from graphgen.templates import KG_EXTRACTION_PROMPT
from graphgen.utils import (
    logger,
    pack_history_conversations,
    split_string_by_multi_markers,
    handle_single_entity_extraction,
    handle_single_relationship_extraction,
    detect_if_chinese,
)
from graphgen.operators.merge_kg import merge_nodes, merge_edges


# pylint: disable=too-many-statements
async def extract_kg(
    llm_client: OpenAIModel,
    kg_instance: BaseGraphStorage,
    tokenizer_instance: Tokenizer,
    chunks: List[Chunk],
    progress_bar: gr.Progress = None,
    max_concurrent: int = 1000,
):
"""
:param llm_client: Synthesizer LLM model to extract entities and relationships
:param kg_instance
:param tokenizer_instance
:param chunks
:param progress_bar: Gradio progress bar to show the progress of the extraction
:param max_concurrent
:return:
"""
    semaphore = asyncio.Semaphore(max_concurrent)

    async def _process_single_content(chunk: Chunk, max_loop: int = 3):
        async with semaphore:
            chunk_id = chunk.id
            content = chunk.content
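
            # Match the prompt language to the chunk's language.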
            if detect_if_chinese(content):
                language = "Chinese"
            else:
                language = "English"
            KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language

            hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
                **KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
            )
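
            # First extraction pass; the exchange is kept as chat history so
            # the follow-up prompts below can build on it.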
            final_result = await llm_client.generate_answer(hint_prompt)
            logger.info("First result: %s", final_result)

            history = pack_history_conversations(hint_prompt, final_result)
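
            # Gleaning loop: ask the LLM whether entities or relationships are
            # still missing and, if so, prompt it to continue extracting, for
            # at most max_loop rounds.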
            for loop_index in range(max_loop):
                if_loop_result = await llm_client.generate_answer(
                    text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"],
                    history=history,
                )
                if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
                if if_loop_result != "yes":
                    break

                glean_result = await llm_client.generate_answer(
                    text=KG_EXTRACTION_PROMPT[language]["CONTINUE"],
                    history=history,
                )
                logger.info("Loop %s glean: %s", loop_index, glean_result)

                history += pack_history_conversations(
                    KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
                )
                final_result += glean_result
                if loop_index == max_loop - 1:
                    break
            records = split_string_by_multi_markers(
                final_result,
                [
                    KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
                    KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
                ],
            )
            nodes = defaultdict(list)
            edges = defaultdict(list)
            for record in records:
                record = re.search(r"\((.*)\)", record)
                if record is None:
                    continue
                record = record.group(1)  # keep only the content inside the parentheses
                record_attributes = split_string_by_multi_markers(
                    record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
                )

                entity = await handle_single_entity_extraction(record_attributes, chunk_id)
                if entity is not None:
                    nodes[entity["entity_name"]].append(entity)
                    continue

                relation = await handle_single_relationship_extraction(record_attributes, chunk_id)
                if relation is not None:
                    edges[(relation["src_id"], relation["tgt_id"])].append(relation)

            return dict(nodes), dict(edges)
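
    # Fan out all chunks concurrently and collect per-chunk (nodes, edges)
    # results as they complete.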
    results = []
    chunk_number = len(chunks)
    async for result in tqdm_async(
        asyncio.as_completed([_process_single_content(c) for c in chunks]),
        total=len(chunks),
        desc="[3/4]Extracting entities and relationships from chunks",
        unit="chunk",
    ):
        try:
            if progress_bar is not None:
                progress_bar(
                    len(results) / chunk_number,
                    desc="[3/4]Extracting entities and relationships from chunks",
                )
            results.append(await result)
            if progress_bar is not None and len(results) == chunk_number:
                progress_bar(1, desc="[3/4]Extracting entities and relationships from chunks")
        except Exception as e:  # pylint: disable=broad-except
            logger.error(
                "Error occurred while extracting entities and relationships from chunks: %s", e
            )
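
    # Aggregate the per-chunk results; edge keys are sorted so that (a, b)
    # and (b, a) refer to the same undirected edge.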
    nodes = defaultdict(list)
    edges = defaultdict(list)
    for n, e in results:
        for k, v in n.items():
            nodes[k].extend(v)
        for k, v in e.items():
            edges[tuple(sorted(k))].extend(v)
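
    # Merge the collected nodes and edges into the graph storage; the LLM
    # client and tokenizer are passed through for the merge step.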
    await merge_nodes(nodes, kg_instance, llm_client, tokenizer_instance)
    await merge_edges(edges, kg_instance, llm_client, tokenizer_instance)

    return kg_instance
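
# A minimal usage sketch, kept as a comment because the surrounding project
# wiring is not shown here. Constructor arguments for OpenAIModel, Tokenizer,
# and the concrete BaseGraphStorage are assumptions, not confirmed API:
#
#     async def main():
#         chunks = [Chunk(id="chunk-0", content="Alan Turing worked at Bletchley Park.")]
#         kg = await extract_kg(
#             llm_client=OpenAIModel(...),        # hypothetical constructor args
#             kg_instance=my_graph_storage,       # any BaseGraphStorage implementation
#             tokenizer_instance=Tokenizer(...),  # hypothetical constructor args
#             chunks=chunks,
#         )
#
#     asyncio.run(main())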