# SexBot / recursive_retriever.py
# NOTE(review): the lines below were Hugging Face web-page residue accidentally
# pasted into the source ("Pew404's picture", "Upload folder using
# huggingface_hub", commit "13fbd2e verified"); kept as a comment so the file parses.
from llama_index.core.schema import IndexNode
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings, SummaryIndex
from llama_index.core.callbacks import LlamaDebugHandler, CallbackManager
from llama_index.core.node_parser import SentenceSplitter
from llama_index.llms.ollama import Ollama
from langchain_community.embeddings.ollama import OllamaEmbeddings
import os
from pathlib import Path
# Node-splitting and debug-callback plumbing shared by every index built below.
splitter = SentenceSplitter()
callback_manager = CallbackManager([LlamaDebugHandler()])

# LLM and embedding model both served by a local Ollama instance; registering
# them on Settings makes them the llama-index global defaults.
llm = Ollama(model="pornchat", base_url="http://localhost:11434", request_timeout=120)
embed_model = OllamaEmbeddings(model="pornchat", base_url="http://localhost:11434")
Settings.llm = llm
Settings.embed_model = embed_model
# File handling
# 1. Collect the file names (without extension)
def get_file_name(file_dir):
    """Return the extension-less names of the regular files in *file_dir*.

    Subdirectories are skipped. Uses ``os.path.splitext`` so a name that
    itself contains dots (e.g. ``"a.b.txt"``) keeps everything up to the
    final extension (``"a.b"``); the previous ``file.split(".")[0]``
    wrongly truncated such names at the first dot.
    """
    return [
        os.path.splitext(entry)[0]
        for entry in os.listdir(file_dir)
        if os.path.isfile(os.path.join(file_dir, entry))
    ]
data_path = "/data1/home/purui/projects/chatbot/tests/data/txt"
titles = get_file_name(data_path)

# 2. Read each file's content, keyed by its title.
docs_dict = {
    title: SimpleDirectoryReader(
        input_files=[f"{data_path}/{title}.txt"]
    ).load_data()[0]
    for title in titles
}
# Per-title query engines / retrievers, plus one summary IndexNode per title
# that the top-level retriever will route through.
vector_query_engines = {}
vector_retrievers = {}
nodes = []

# Summaries are cached on disk so re-runs don't re-query the LLM.
# Create the cache directory once, up front (previously mkdir ran inside the
# loop and only on the write branch).
summary_dir = Path(f"{data_path}/summaries")
summary_dir.mkdir(exist_ok=True)

for title in titles:
    # Build the per-document vector index (re-embedded on every run).
    vector_index = VectorStoreIndex.from_documents(
        [docs_dict[title]],
        transformations=[splitter],
        callback_manager=callback_manager,
    )
    vector_query_engines[title] = vector_index.as_query_engine()
    vector_retrievers[title] = vector_index.as_retriever()

    out_path = summary_dir / f"{title}.txt"
    if out_path.exists():
        # Cache hit: reuse the stored summary. Explicit utf-8 avoids
        # platform-dependent default encodings mangling non-ASCII text.
        summary = out_path.read_text(encoding="utf-8")
    else:
        # Cache miss: summarize the document with the LLM and persist it.
        summary_index = SummaryIndex.from_documents(
            [docs_dict[title]], callback_manager=callback_manager
        )
        summarizer = summary_index.as_query_engine(
            response_mode="tree_summarize", llm=llm
        )
        response = summarizer.query(f"Give a summary of {title}")
        summary = response.response
        out_path.write_text(summary, encoding="utf-8")

    # IndexNode whose index_id links the summary back to the title's retriever.
    nodes.append(IndexNode(text=summary, index_id=title))
# define recursive retriever
from llama_index.core.retrievers import RecursiveRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core import get_response_synthesizer
# note: can pass `agents` dict as `query_engine_dict` since every agent can be used as a query engine
# define top-level retriever
# Top-level index over the per-document summary nodes; persisted so it can be
# reloaded later without rebuilding.
top_vector_index = VectorStoreIndex(
    nodes, transformations=[splitter], callback_manager=callback_manager
)
top_vector_index.storage_context.persist("/data1/home/purui/projects/chatbot/kb/sexual_knowledge")

# Route each query to the single best summary node first, then recurse into
# that document's own retriever via the IndexNode's index_id.
top_vector_retriever = top_vector_index.as_retriever(similarity_top_k=1)
retriever_dict = {"vector": top_vector_retriever}
retriever_dict.update(vector_retrievers)
recursive_retriever = RecursiveRetriever(
    "vector",
    retriever_dict=retriever_dict,
    verbose=True,
)
# Smoke-test the recursive retriever and print the retrieved chunks.
retrieved_nodes = recursive_retriever.retrieve("How to make good blow job?")
for retrieved in retrieved_nodes:
    print(retrieved.node.get_content())
# Persist one vector index per file and a top index over the per-file
# summaries; at load time: load_index(index dir) -> build top index -> query.