Spaces:
Running
Running
import re | |
import asyncio | |
from typing import List | |
from collections import defaultdict | |
import gradio as gr | |
from tqdm.asyncio import tqdm as tqdm_async | |
from graphgen.models import Chunk, OpenAIModel, Tokenizer | |
from graphgen.models.storage.base_storage import BaseGraphStorage | |
from graphgen.templates import KG_EXTRACTION_PROMPT | |
from graphgen.utils import (logger, pack_history_conversations, split_string_by_multi_markers, | |
handle_single_entity_extraction, handle_single_relationship_extraction, | |
detect_if_chinese) | |
from graphgen.operators.merge_kg import merge_nodes, merge_edges | |
# pylint: disable=too-many-statements | |
# Matches the text between the first "(" and the last ")" on a line of an LLM
# extraction record, e.g. '("entity"<|>NAME<|>TYPE<|>DESC)'.
_RECORD_PATTERN = re.compile(r"\((.*)\)")


async def _parse_extraction_output(raw_output, chunk_id):
    """Parse raw LLM extraction text for one chunk into entity and relationship records.

    :param raw_output: Concatenated LLM output (first answer plus any gleanings).
    :param chunk_id: Id of the chunk the output was extracted from; forwarded to
        the record handlers.
    :return: ``(nodes, edges)`` where ``nodes`` maps entity name -> list of entity
        dicts and ``edges`` maps ``(src_id, tgt_id)`` -> list of relation dicts.
    """
    fmt = KG_EXTRACTION_PROMPT["FORMAT"]
    records = split_string_by_multi_markers(
        raw_output,
        [fmt["record_delimiter"], fmt["completion_delimiter"]],
    )

    nodes = defaultdict(list)
    edges = defaultdict(list)
    for record in records:
        match = _RECORD_PATTERN.search(record)
        if match is None:
            continue
        # Keep only the content inside the parentheses, then split it into fields.
        attributes = split_string_by_multi_markers(match.group(1), [fmt["tuple_delimiter"]])

        # A record is tried as an entity first; only if that fails is it tried
        # as a relationship.
        entity = await handle_single_entity_extraction(attributes, chunk_id)
        if entity is not None:
            nodes[entity["entity_name"]].append(entity)
            continue
        relation = await handle_single_relationship_extraction(attributes, chunk_id)
        if relation is not None:
            edges[(relation["src_id"], relation["tgt_id"])].append(relation)
    return dict(nodes), dict(edges)


async def extract_kg(
    llm_client: OpenAIModel,
    kg_instance: BaseGraphStorage,
    tokenizer_instance: Tokenizer,
    chunks: List[Chunk],
    progress_bar: gr.Progress = None,
    max_concurrent: int = 1000,
    max_loop: int = 3,
):
    """
    Extract entities and relationships from chunks with an LLM and merge them into a graph.

    :param llm_client: Synthesizer LLM model to extract entities and relationships
    :param kg_instance: Graph storage the extracted nodes and edges are merged into
    :param tokenizer_instance: Tokenizer forwarded to merge_nodes / merge_edges
    :param chunks: Chunks whose ``content`` is sent to the LLM
    :param progress_bar: Gradio progress bar to show the progress of the extraction
    :param max_concurrent: Maximum number of chunks processed concurrently (must be >= 1)
    :param max_loop: Maximum number of extra "gleaning" rounds per chunk, in which
        the LLM is asked whether more entities remain and, if so, to continue
    :return: ``kg_instance`` after merge_nodes and merge_edges have been applied to it
    :raises ValueError: if ``max_concurrent`` is smaller than 1
    """
    if max_concurrent < 1:
        # Semaphore(0) would block every chunk forever.
        raise ValueError(f"max_concurrent must be >= 1, got {max_concurrent}")

    semaphore = asyncio.Semaphore(max_concurrent)
    progress_desc = "[3/4]Extracting entities and relationships from chunks"

    async def _process_single_content(chunk: Chunk):
        """Run the extraction prompt (plus gleaning rounds) for one chunk and parse the output."""
        async with semaphore:
            content = chunk.content
            language = "Chinese" if detect_if_chinese(content) else "English"
            prompts = KG_EXTRACTION_PROMPT[language]

            # Build a per-call copy of the format arguments instead of writing
            # the language into the shared module-level FORMAT dict.
            format_args = {**KG_EXTRACTION_PROMPT["FORMAT"], "language": language}
            hint_prompt = prompts["TEMPLATE"].format(**format_args, input_text=content)

            final_result = await llm_client.generate_answer(hint_prompt)
            logger.info('First result: %s', final_result)
            history = pack_history_conversations(hint_prompt, final_result)

            for loop_index in range(max_loop):
                if_loop_result = await llm_client.generate_answer(
                    text=prompts["IF_LOOP"],
                    history=history
                )
                # Normalise answers such as ' "Yes" ' before comparing.
                if if_loop_result.strip().strip('"').strip("'").lower() != "yes":
                    break
                glean_result = await llm_client.generate_answer(
                    text=prompts["CONTINUE"],
                    history=history
                )
                logger.info('Loop %s glean: %s', loop_index, glean_result)
                history += pack_history_conversations(prompts["CONTINUE"], glean_result)
                final_result += glean_result

            return await _parse_extraction_output(final_result, chunk.id)

    total = len(chunks)
    results = []
    completed = 0
    async for task in tqdm_async(
        asyncio.as_completed([_process_single_content(c) for c in chunks]),
        total=total,
        desc=progress_desc,
        unit="chunk",
    ):
        try:
            results.append(await task)
        except Exception as e:  # pylint: disable=broad-except
            # Best effort: a failing chunk is logged and skipped, not fatal.
            logger.error("Error occurred while extracting entities and relationships from chunks: %s", e)
        # Count failures too, so the bar still reaches 1 when some chunks fail.
        completed += 1
        if progress_bar is not None:
            progress_bar(completed / total, desc=progress_desc)

    nodes = defaultdict(list)
    edges = defaultdict(list)
    for chunk_nodes, chunk_edges in results:
        for name, entities in chunk_nodes.items():
            nodes[name].extend(entities)
        for pair, relations in chunk_edges.items():
            # Sort the endpoints so (a, b) and (b, a) merge into one edge.
            edges[tuple(sorted(pair))].extend(relations)

    await merge_nodes(nodes, kg_instance, llm_client, tokenizer_instance)
    await merge_edges(edges, kg_instance, llm_client, tokenizer_instance)
    return kg_instance