# graphgen/operators/extract_kg.py
import asyncio
import re
from collections import defaultdict
from typing import List, Optional

import gradio as gr
from tqdm.asyncio import tqdm as tqdm_async

from graphgen.models import Chunk, OpenAIModel, Tokenizer
from graphgen.models.storage.base_storage import BaseGraphStorage
from graphgen.operators.merge_kg import merge_edges, merge_nodes
from graphgen.templates import KG_EXTRACTION_PROMPT
from graphgen.utils import (
    detect_if_chinese,
    handle_single_entity_extraction,
    handle_single_relationship_extraction,
    logger,
    pack_history_conversations,
    split_string_by_multi_markers,
)


# pylint: disable=too-many-statements
async def extract_kg(
    llm_client: OpenAIModel,
    kg_instance: BaseGraphStorage,
    tokenizer_instance: Tokenizer,
    chunks: List[Chunk],
    progress_bar: Optional[gr.Progress] = None,
    max_concurrent: int = 1000,
):
"""
:param llm_client: Synthesizer LLM model to extract entities and relationships
:param kg_instance
:param tokenizer_instance
:param chunks
:param progress_bar: Gradio progress bar to show the progress of the extraction
:param max_concurrent
:return:
"""
    semaphore = asyncio.Semaphore(max_concurrent)

    async def _process_single_content(chunk: Chunk, max_loop: int = 3):
        async with semaphore:
            chunk_id = chunk.id
            content = chunk.content
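
            # Use the prompt template matching the chunk's language.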
            if detect_if_chinese(content):
                language = "Chinese"
            else:
                language = "English"
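            # Record the language in the shared FORMAT dict. The prompt below is
            # built before this coroutine yields, so under asyncio's single-threaded
            # execution no other chunk can interleave between the write and the read.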
            KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language

            hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
                **KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
            )
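
            # First extraction pass.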
            final_result = await llm_client.generate_answer(hint_prompt)
            logger.info('First result: %s', final_result)

            history = pack_history_conversations(hint_prompt, final_result)
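
            # Gleaning: ask the LLM whether anything was missed and, if so,
            # continue extracting for up to max_loop rounds.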
            for loop_index in range(max_loop):
                if_loop_result = await llm_client.generate_answer(
                    text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"],
                    history=history,
                )
                if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
                if if_loop_result != "yes":
                    break

                glean_result = await llm_client.generate_answer(
                    text=KG_EXTRACTION_PROMPT[language]["CONTINUE"],
                    history=history,
                )
                logger.info('Loop %s glean: %s', loop_index, glean_result)

                history += pack_history_conversations(
                    KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
                )
                final_result += glean_result
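
            # Split the accumulated output into records on the delimiters defined
            # by the prompt format.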
            records = split_string_by_multi_markers(
                final_result,
                [
                    KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
                    KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
                ],
            )
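
            # Collect entities and relationships extracted from this chunk.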
            nodes = defaultdict(list)
            edges = defaultdict(list)
            for record in records:
                match = re.search(r"\((.*)\)", record)
                if match is None:
                    continue
                # Keep only the content inside the parentheses.
                record = match.group(1)
                record_attributes = split_string_by_multi_markers(
                    record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
                )
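
                # Try to parse the record as an entity first; if that fails,
                # try it as a relationship.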
                entity = await handle_single_entity_extraction(record_attributes, chunk_id)
                if entity is not None:
                    nodes[entity["entity_name"]].append(entity)
                    continue

                relation = await handle_single_relationship_extraction(record_attributes, chunk_id)
                if relation is not None:
                    edges[(relation["src_id"], relation["tgt_id"])].append(relation)

            return dict(nodes), dict(edges)

    results = []
    chunk_number = len(chunks)
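
    # Run extraction for all chunks concurrently, collecting results as they finish.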
    async for result in tqdm_async(
        asyncio.as_completed([_process_single_content(c) for c in chunks]),
        total=chunk_number,
        desc="[3/4]Extracting entities and relationships from chunks",
        unit="chunk",
    ):
        try:
            results.append(await result)
            if progress_bar is not None:
                # Report progress after storing the result so the bar reaches 1
                # exactly when the final chunk completes.
                progress_bar(
                    len(results) / chunk_number,
                    desc="[3/4]Extracting entities and relationships from chunks",
                )
        except Exception as e:  # pylint: disable=broad-except
            logger.error(
                "Error occurred while extracting entities and relationships from chunks: %s",
                e,
            )
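
    # Aggregate per-chunk results. Edge keys are sorted so that (a, b) and (b, a)
    # count as the same undirected edge.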
    nodes = defaultdict(list)
    edges = defaultdict(list)
    for chunk_nodes, chunk_edges in results:
        for k, v in chunk_nodes.items():
            nodes[k].extend(v)
        for k, v in chunk_edges.items():
            edges[tuple(sorted(k))].extend(v)
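
    # Merge the aggregated nodes and edges into the graph storage.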
    await merge_nodes(nodes, kg_instance, llm_client, tokenizer_instance)
    await merge_edges(edges, kg_instance, llm_client, tokenizer_instance)

    return kg_instance
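

# Example usage (a minimal sketch; the keyword arguments to Chunk are an
# assumption and may differ from the actual GraphGen API):
#
#   chunks = [Chunk(id="chunk-0", content="Alan Turing worked at Bletchley Park.")]
#   kg = await extract_kg(llm_client, kg_instance, tokenizer_instance, chunks)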