Spaces:
Running
Running
import asyncio | |
from graphgen.models import WikiSearch, OpenAIModel | |
from graphgen.models.storage.base_storage import BaseGraphStorage | |
from graphgen.templates import SEARCH_JUDGEMENT_PROMPT | |
from graphgen.utils import logger | |
async def _process_single_entity(entity_name: str,
                                 description: str,
                                 llm_client: OpenAIModel,
                                 wiki_search_client: WikiSearch) -> tuple[str, None] | tuple[str, str]:
    """
    Look up one entity on Wikipedia and let the LLM pick the matching result.

    The entity name is searched through ``wiki_search_client``. The results,
    followed by a trailing "None of the above" option, are rendered as a
    1-based numbered list and inserted into ``SEARCH_JUDGEMENT_PROMPT``
    together with the entity name and description. The LLM reply is expected
    to be the number of the chosen option.

    :param entity_name: name of the entity to search for
    :param description: entity description passed to the judgement prompt
    :param llm_client: LLM model used to judge the search results
    :param wiki_search_client: wiki search client
    :return: ``(entity_name, summary)`` where ``summary`` is whatever
        ``wiki_search_client.summary`` returns for the chosen result, or
        ``(entity_name, None)`` when the search returned nothing, the reply
        was not an integer, or the reply selected "None of the above" or an
        out-of-range option
    """
    search_results = await wiki_search_client.search(entity_name)
    if not search_results:
        return entity_name, None

    # Build the option list as a new list: appending the sentinel to
    # `search_results` in place would modify an object handed out by the
    # search client, which may be reused or cached on its side.
    options = [*search_results, "None of the above"]

    examples = "\n".join(SEARCH_JUDGEMENT_PROMPT["EXAMPLES"])
    search_results_str = "\n".join(
        f"{number}. {option}" for number, option in enumerate(options, start=1)
    )
    prompt = SEARCH_JUDGEMENT_PROMPT["TEMPLATE"].format(
        examples=examples,
        entity_name=entity_name,
        description=description,
        search_results=search_results_str,
    )
    answer = await llm_client.generate_answer(prompt)

    summary = None
    try:
        choice = int(answer.strip())
        # The last option is the "None of the above" sentinel, so only
        # 1 .. len(options) - 1 refer to an actual search result.
        if 1 <= choice < len(options):
            summary = await wiki_search_client.summary(options[choice - 1])
    except ValueError:
        # A reply that is not a bare integer is treated as "no match".
        summary = None

    logger.info("Entity %s search result: %s response: %s", entity_name, options, summary)
    return entity_name, summary
async def search_wikipedia(llm_client: OpenAIModel,
                           wiki_search_client: WikiSearch,
                           knowledge_graph_instance: BaseGraphStorage,) -> dict:
    """
    Run the Wikipedia lookup for every node of the knowledge graph.

    Each node is read as an indexable pair: element 0 is the entity name
    (surrounding double quotes are stripped) and element 1 is a mapping that
    holds a "description" entry. One lookup coroutine is created per node and
    the results are collected in the order in which the lookups finish.

    :param llm_client: LLM model
    :param wiki_search_client: wiki search client
    :param knowledge_graph_instance: knowledge graph instance
    :return: dict mapping each entity name to its lookup result
    """
    all_nodes = list(await knowledge_graph_instance.get_all_nodes())

    lookups = []
    for node in all_nodes:
        entity_name = node[0].strip('"')
        entity_description = node[1]["description"]
        lookups.append(
            _process_single_entity(entity_name, entity_description, llm_client, wiki_search_client)
        )

    found = {}
    # as_completed yields in completion order, so entries are inserted (and
    # duplicate names overwritten) in the order the lookups finish.
    for finished in asyncio.as_completed(lookups):
        name, outcome = await finished
        found[name] = outcome
    return found