import asyncio
from collections import Counter

from tqdm.asyncio import tqdm as tqdm_async

from graphgen.models import Tokenizer, TopkTokenModel
from graphgen.models.storage.base_storage import BaseGraphStorage
from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
from graphgen.utils import detect_main_language, logger
from graphgen.utils.format import split_string_by_multi_markers


async def _handle_kg_summary(
entity_or_relation_name: str,
description: str,
llm_client: TopkTokenModel,
tokenizer_instance: Tokenizer,
max_summary_tokens: int = 200
) -> str:
"""
处理实体或关系的描述信息
:param entity_or_relation_name
:param description
:param llm_client
:param tokenizer_instance
:param max_summary_tokens
:return: new description
"""
language = detect_main_language(description)
if language == "en":
language = "English"
else:
language = "Chinese"
KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
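    # Short descriptions are returned unchanged; longer ones are truncated to the
    # token budget and summarized by the LLM below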
tokens = tokenizer_instance.encode_string(description)
if len(tokens) < max_summary_tokens:
return description
use_description = tokenizer_instance.decode_tokens(tokens[:max_summary_tokens])
prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
entity_name=entity_or_relation_name,
description_list=use_description.split('<SEP>'),
**KG_SUMMARIZATION_PROMPT["FORMAT"]
)
new_description = await llm_client.generate_answer(prompt)
logger.info("Entity or relation %s summary: %s", entity_or_relation_name, new_description)
return new_description
async def merge_nodes(
nodes_data: dict,
kg_instance: BaseGraphStorage,
llm_client: TopkTokenModel,
tokenizer_instance: Tokenizer,
max_concurrent: int = 1000
):
"""
Merge nodes
:param nodes_data
:param kg_instance
:param llm_client
:param tokenizer_instance
:param max_concurrent
:return
"""
semaphore = asyncio.Semaphore(max_concurrent)
async def process_single_node(entity_name: str, node_data: list[dict]):
async with semaphore:
entity_types = []
source_ids = []
descriptions = []
node = await kg_instance.get_node(entity_name)
if node is not None:
entity_types.append(node["entity_type"])
source_ids.extend(
split_string_by_multi_markers(node["source_id"], ['<SEP>'])
)
descriptions.append(node["description"])
            # Count entity_type occurrences across the incoming records and the existing node,
            # then keep the most frequent entity_type
entity_type = sorted(
Counter(
[dp["entity_type"] for dp in node_data] + entity_types
).items(),
key=lambda x: x[1],
reverse=True,
)[0][0]
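            # Merge every unique description; _handle_kg_summary condenses the result if it is too long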
description = '<SEP>'.join(
sorted(set([dp["description"] for dp in node_data] + descriptions))
)
description = await _handle_kg_summary(
entity_name, description, llm_client, tokenizer_instance
)
source_id = '<SEP>'.join(
set([dp["source_id"] for dp in node_data] + source_ids)
)
node_data = {
"entity_type": entity_type,
"description": description,
"source_id": source_id
}
await kg_instance.upsert_node(
entity_name,
node_data=node_data
)
node_data["entity_name"] = entity_name
return node_data
logger.info("Inserting entities into storage...")
entities_data = []
for result in tqdm_async(
asyncio.as_completed(
[process_single_node(k, v) for k, v in nodes_data.items()]
),
total=len(nodes_data),
desc="Inserting entities into storage",
unit="entity",
):
try:
entities_data.append(await result)
except Exception as e: # pylint: disable=broad-except
logger.error("Error occurred while inserting entities into storage: %s", e)
async def merge_edges(
edges_data: dict,
kg_instance: BaseGraphStorage,
llm_client: TopkTokenModel,
tokenizer_instance: Tokenizer,
max_concurrent: int = 1000
):
"""
Merge edges
:param edges_data
:param kg_instance
:param llm_client
:param tokenizer_instance
:param max_concurrent
:return
"""
semaphore = asyncio.Semaphore(max_concurrent)
async def process_single_edge(src_id: str, tgt_id: str, edge_data: list[dict]):
async with semaphore:
source_ids = []
descriptions = []
edge = await kg_instance.get_edge(src_id, tgt_id)
if edge is not None:
source_ids.extend(
split_string_by_multi_markers(edge["source_id"], ['<SEP>'])
)
descriptions.append(edge["description"])
description = '<SEP>'.join(
sorted(set([dp["description"] for dp in edge_data] + descriptions))
)
source_id = '<SEP>'.join(
set([dp["source_id"] for dp in edge_data] + source_ids)
)
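            # Ensure both endpoints exist before upserting the edge; create placeholder
            # nodes with entity_type "UNKNOWN" for any missing endpoint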
for insert_id in [src_id, tgt_id]:
if not await kg_instance.has_node(insert_id):
await kg_instance.upsert_node(
insert_id,
node_data={
"source_id": source_id,
"description": description,
"entity_type": "UNKNOWN"
}
)
description = await _handle_kg_summary(
f"({src_id}, {tgt_id})", description, llm_client, tokenizer_instance
)
await kg_instance.upsert_edge(
src_id,
tgt_id,
edge_data={
"source_id": source_id,
"description": description
}
)
edge_data = {
"src_id": src_id,
"tgt_id": tgt_id,
"description": description
}
return edge_data
logger.info("Inserting relationships into storage...")
relationships_data = []
for result in tqdm_async(
asyncio.as_completed(
[process_single_edge(src_id, tgt_id, v) for (src_id, tgt_id), v in edges_data.items()]
),
total=len(edges_data),
desc="Inserting relationships into storage",
unit="relationship",
):
try:
relationships_data.append(await result)
except Exception as e: # pylint: disable=broad-except
logger.error("Error occurred while inserting relationships into storage: %s", e)
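
# --------------------------------------------------------------------------- #
# Illustrative usage sketch (not part of the module): one way the two merge
# coroutines might be driven. The storage, LLM client, and tokenizer objects,
# as well as the sample records below, are placeholders; the record fields
# ("entity_type", "description", "source_id") mirror the keys accessed above.
#
# async def run_merge(kg_instance, llm_client, tokenizer_instance):
#     nodes_data = {
#         "Alan Turing": [
#             {"entity_type": "PERSON",
#              "description": "British mathematician and computer scientist",
#              "source_id": "doc-1"},
#         ],
#     }
#     edges_data = {
#         ("Alan Turing", "Enigma"): [
#             {"description": "Worked on breaking the Enigma cipher",
#              "source_id": "doc-1"},
#         ],
#     }
#     await merge_nodes(nodes_data, kg_instance, llm_client, tokenizer_instance)
#     await merge_edges(edges_data, kg_instance, llm_client, tokenizer_instance)
# --------------------------------------------------------------------------- #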