# Adapted from https://github.com/HKUDS/LightRAG
import asyncio
import os
import time
from dataclasses import dataclass, field
from typing import List, Union, cast

import gradio as gr
from tqdm.asyncio import tqdm as tqdm_async

from .models import (
    Chunk,
    JsonKVStorage,
    NetworkXStorage,
    OpenAIModel,
    Tokenizer,
    TraverseStrategy,
    WikiSearch,
)
from .models.storage.base_storage import StorageNameSpace
from .operators import (
    extract_kg,
    judge_statement,
    quiz,
    search_wikipedia,
    skip_judge_statement,
    traverse_graph_atomically,
    traverse_graph_by_edge,
    traverse_graph_for_multi_hop,
)
from .utils import compute_content_hash, create_event_loop, logger

sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))


@dataclass
class GraphGen:
    """Build a knowledge graph from text and synthesize QA data by traversing it."""

    unique_id: int = int(time.time())
    working_dir: str = os.path.join(sys_path, "cache")

    # text chunking
    chunk_size: int = 1024
    chunk_overlap_size: int = 100

    # llm
    synthesizer_llm_client: OpenAIModel = None
    trainee_llm_client: OpenAIModel = None
    tokenizer_instance: Tokenizer = None

    # web search
    if_web_search: bool = False
    wiki_client: WikiSearch = field(default_factory=WikiSearch)

    # traverse strategy
    traverse_strategy: TraverseStrategy = field(default_factory=TraverseStrategy)

    # webui
    progress_bar: gr.Progress = None

    def __post_init__(self):
        self.full_docs_storage: JsonKVStorage = JsonKVStorage(
            self.working_dir, namespace="full_docs"
        )
        self.text_chunks_storage: JsonKVStorage = JsonKVStorage(
            self.working_dir, namespace="text_chunks"
        )
        self.wiki_storage: JsonKVStorage = JsonKVStorage(
            self.working_dir, namespace="wiki"
        )
        self.graph_storage: NetworkXStorage = NetworkXStorage(
            self.working_dir, namespace="graph"
        )
        self.rephrase_storage: JsonKVStorage = JsonKVStorage(
            self.working_dir, namespace="rephrase"
        )
        self.qa_storage: JsonKVStorage = JsonKVStorage(
            os.path.join(self.working_dir, "data", "graphgen", str(self.unique_id)),
            namespace=f"qa-{self.unique_id}",
        )

    async def async_split_chunks(
        self, data: Union[List[list], List[dict]], data_type: str
    ) -> dict:
        # TODO: decide whether to apply coreference resolution before chunking
        if len(data) == 0:
            return {}

        new_docs = {}
        inserting_chunks = {}

        if data_type == "raw":
            assert isinstance(data, list) and isinstance(data[0], dict)
            # compute hash for each document
            new_docs = {
                compute_content_hash(doc["content"], prefix="doc-"): {
                    "content": doc["content"]
                }
                for doc in data
            }
            _add_doc_keys = await self.full_docs_storage.filter_keys(
                list(new_docs.keys())
            )
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if len(new_docs) == 0:
                logger.warning("All docs are already in the storage")
                return {}
            logger.info("[New Docs] inserting %d docs", len(new_docs))

            cur_index = 1
            doc_number = len(new_docs)
            async for doc_key, doc in tqdm_async(
                new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
            ):
                chunks = {
                    compute_content_hash(dp["content"], prefix="chunk-"): {
                        **dp,
                        "full_doc_id": doc_key,
                    }
                    for dp in self.tokenizer_instance.chunk_by_token_size(
                        doc["content"], self.chunk_overlap_size, self.chunk_size
                    )
                }
                inserting_chunks.update(chunks)

                if self.progress_bar is not None:
                    self.progress_bar(cur_index / doc_number, f"Chunking {doc_key}")
                cur_index += 1

            _add_chunk_keys = await self.text_chunks_storage.filter_keys(
                list(inserting_chunks.keys())
            )
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }
        elif data_type == "chunked":
            assert isinstance(data, list) and isinstance(data[0], list)
            # a chunked "doc" is a list of chunks; its id is the hash of the
            # concatenated chunk contents, matching full_doc_id computed below
            new_docs = {
                compute_content_hash(
                    "".join(chunk["content"] for chunk in doc), prefix="doc-"
                ): {"content": "".join(chunk["content"] for chunk in doc)}
                for doc in data
            }
            _add_doc_keys = await self.full_docs_storage.filter_keys(
                list(new_docs.keys())
            )
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if len(new_docs) == 0:
                logger.warning("All docs are already in the storage")
                return {}
            logger.info("[New Docs] inserting %d docs", len(new_docs))

            async for doc in tqdm_async(
                data, desc="[1/4]Chunking documents", unit="doc"
            ):
                doc_str = "".join([chunk["content"] for chunk in doc])
                for chunk in doc:
                    chunk_key = compute_content_hash(chunk["content"], prefix="chunk-")
                    inserting_chunks[chunk_key] = {
                        **chunk,
                        "full_doc_id": compute_content_hash(doc_str, prefix="doc-"),
                    }
            _add_chunk_keys = await self.text_chunks_storage.filter_keys(
                list(inserting_chunks.keys())
            )
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }

        await self.full_docs_storage.upsert(new_docs)
        await self.text_chunks_storage.upsert(inserting_chunks)

        return inserting_chunks

    def insert(self, data: Union[List[list], List[dict]], data_type: str):
        loop = create_event_loop()
        loop.run_until_complete(self.async_insert(data, data_type))

    async def async_insert(self, data: Union[List[list], List[dict]], data_type: str):
        """Insert chunks into the graph."""
        inserting_chunks = await self.async_split_chunks(data, data_type)
        if len(inserting_chunks) == 0:
            logger.warning("All chunks are already in the storage")
            return
        logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))

        logger.info("[Entity and Relation Extraction]...")
        _add_entities_and_relations = await extract_kg(
            llm_client=self.synthesizer_llm_client,
            kg_instance=self.graph_storage,
            tokenizer_instance=self.tokenizer_instance,
            chunks=[
                Chunk(id=k, content=v["content"])
                for k, v in inserting_chunks.items()
            ],
            progress_bar=self.progress_bar,
        )
        if not _add_entities_and_relations:
            logger.warning("No entities or relations extracted")
            return

        logger.info(
            "[Wiki Search] is %s", "enabled" if self.if_web_search else "disabled"
        )
        if self.if_web_search:
            logger.info("[Wiki Search]...")
            _add_wiki_data = await search_wikipedia(
                llm_client=self.synthesizer_llm_client,
                wiki_search_client=self.wiki_client,
                knowledge_graph_instance=_add_entities_and_relations,
            )
            await self.wiki_storage.upsert(_add_wiki_data)

        await self._insert_done()

    async def _insert_done(self):
        tasks = []
        for storage_instance in [
            self.full_docs_storage,
            self.text_chunks_storage,
            self.graph_storage,
            self.wiki_storage,
        ]:
            if storage_instance is None:
                continue
            tasks.append(
                cast(StorageNameSpace, storage_instance).index_done_callback()
            )
        await asyncio.gather(*tasks)

    def quiz(self, max_samples=1):
        loop = create_event_loop()
        loop.run_until_complete(self.async_quiz(max_samples))

    async def async_quiz(self, max_samples=1):
        await quiz(
            self.synthesizer_llm_client,
            self.graph_storage,
            self.rephrase_storage,
            max_samples,
        )
        await self.rephrase_storage.index_done_callback()

    def judge(self, re_judge=False, skip=False):
        loop = create_event_loop()
        loop.run_until_complete(self.async_judge(re_judge, skip))

    async def async_judge(self, re_judge=False, skip=False):
        if skip:
            _update_relations = await skip_judge_statement(self.graph_storage)
        else:
            _update_relations = await judge_statement(
                self.trainee_llm_client,
                self.graph_storage,
                self.rephrase_storage,
                re_judge,
            )
        await _update_relations.index_done_callback()

    def traverse(self):
        loop = create_event_loop()
        loop.run_until_complete(self.async_traverse())

    async def async_traverse(self):
        if self.traverse_strategy.qa_form == "atomic":
            results = await traverse_graph_atomically(
                self.synthesizer_llm_client,
                self.tokenizer_instance,
                self.graph_storage,
                self.traverse_strategy,
                self.text_chunks_storage,
                self.progress_bar,
            )
        elif self.traverse_strategy.qa_form == "multi_hop":
            results = await traverse_graph_for_multi_hop(
                self.synthesizer_llm_client,
                self.tokenizer_instance,
                self.graph_storage,
                self.traverse_strategy,
                self.text_chunks_storage,
                self.progress_bar,
            )
        elif self.traverse_strategy.qa_form == "aggregated":
            results = await traverse_graph_by_edge(
                self.synthesizer_llm_client,
                self.tokenizer_instance,
                self.graph_storage,
                self.traverse_strategy,
                self.text_chunks_storage,
                self.progress_bar,
            )
        else:
            raise ValueError(f"Unknown qa_form: {self.traverse_strategy.qa_form}")

        await self.qa_storage.upsert(results)
        await self.qa_storage.index_done_callback()

    def clear(self):
        loop = create_event_loop()
        loop.run_until_complete(self.async_clear())

    async def async_clear(self):
        await self.full_docs_storage.drop()
        await self.text_chunks_storage.drop()
        await self.wiki_storage.drop()
        await self.graph_storage.clear()
        await self.rephrase_storage.drop()
        await self.qa_storage.drop()

        logger.info("All caches are cleared")
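

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the adapted module): one plausible
# end-to-end run of the pipeline above. Construction of the LLM clients and
# tokenizer is elided because the OpenAIModel/Tokenizer signatures live in
# .models and are not shown here; run_pipeline is a hypothetical helper that
# only uses the public methods defined on GraphGen.
# ---------------------------------------------------------------------------
def run_pipeline(graph_gen: GraphGen) -> None:
    # 1) Chunk raw documents, extract entities/relations, build the graph.
    graph_gen.insert(
        [{"content": "Marie Curie won Nobel Prizes in physics and chemistry."}],
        data_type="raw",
    )
    # 2) Generate rephrased statements from the graph for comprehension probing.
    graph_gen.quiz(max_samples=2)
    # 3) Have the trainee model judge the statements (skip=True bypasses this).
    graph_gen.judge(re_judge=False, skip=False)
    # 4) Traverse the graph per traverse_strategy.qa_form ("atomic",
    #    "multi_hop", or "aggregated") and write QA pairs to qa_storage.
    graph_gen.traverse()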