# SexBot / utils/recursive_retrieve.py
# Hugging Face page header (not code): "Pew404's picture",
# "Upload folder using huggingface_hub", 318db6e verified
from llama_index.core.schema import IndexNode
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings, SummaryIndex, load_index_from_storage, StorageContext, Document
from llama_index.core.callbacks import LlamaDebugHandler, CallbackManager
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.base.base_query_engine import BaseQueryEngine
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.retrievers import RecursiveRetriever
from llama_index.llms.ollama import Ollama
from langchain_community.embeddings.ollama import OllamaEmbeddings
from llama_index.core.retrievers import RecursiveRetriever
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.vector_stores.chroma import ChromaVectorStore
import Stemmer
from typing import List, Dict, Optional
import os
from pathlib import Path
import chromadb
# Module-level configuration, evaluated once at import time.
# Chat model served by a local Ollama instance; request_timeout is set to 240.
llm = Ollama(model="pornchat", base_url="http://localhost:11434", request_timeout=240)
# NOTE(review): OllamaEmbeddings is imported from langchain_community, not from
# llama_index — confirm that llama_index accepts it as Settings.embed_model in
# the installed versions.
embed_model = OllamaEmbeddings(model="pornchat", base_url="http://localhost:11434")
# Register the model and the embedder on the llama_index Settings object.
Settings.llm = llm
Settings.embed_model = embed_model
# Splitter and debug callback manager passed to VectorStoreIndex.from_documents
# when the per-file indexes are built in prepare_nodes.
splitter = SentenceSplitter()
callback_manager = CallbackManager([LlamaDebugHandler()])
# Hard-coded absolute locations of the raw text data and the persisted indexes
# (a test set and a full set). The functions below take data_dir / index_dir as
# parameters, which shadow the two module-level names of the same spelling.
test_data_dir = "/data1/home/purui/projects/chatbot/tests/data/txt"
test_index_dir = "/data1/home/purui/projects/chatbot/tests/kb"
data_dir = "/data1/home/purui/projects/chatbot/data/txt"
index_dir = "/data1/home/purui/projects/chatbot/kb"
def get_file_name(file_dir):
    """Return the extension-less names of the regular files directly in ``file_dir``.

    Sub-directories are ignored and only the last extension is removed
    (``"b.tar.gz"`` becomes ``"b.tar"``). The order is whatever
    ``os.listdir`` yields.
    """
    return [
        os.path.splitext(entry)[0]
        for entry in os.listdir(file_dir)
        if os.path.isfile(os.path.join(file_dir, entry))
    ]
def get_dir_name(file_dir):
    """Return the names of the sub-directories directly inside ``file_dir``.

    Regular files are ignored. Directory names are returned unchanged: a dot
    in a directory name is part of the name, not a file extension, so it must
    not be split off (``"v1.0"`` stays ``"v1.0"``). The order is whatever
    ``os.listdir`` yields.
    """
    return [
        entry
        for entry in os.listdir(file_dir)
        if os.path.isdir(os.path.join(file_dir, entry))
    ]
def _parse_qa_file(path):
    """Return ``(question, answers_text)`` parsed from a ``Q:`` / ``A:`` text file.

    The first line starting with ``Q:`` supplies the question; every later
    line starting with ``A:`` supplies one answer. Lines before the first
    ``Q:`` line are ignored, so a file without a question yields ``("", "")``.
    """
    question = ""
    with open(path, encoding="utf-8") as f:
        lines = iter(f)
        for line in lines:
            if line.startswith("Q:"):
                # Split on the FIRST colon only so that colons inside the
                # question text are kept.
                question = line.partition(":")[2].strip()
                break
        # The iterator resumes right after the question line.
        answers = [
            line.partition(":")[2].strip()
            for line in lines
            if line.startswith("A:")
        ]
    return question, "\n".join(answer for answer in answers if answer)


def _index_document(doc, title, index_dir, collection_name, chroma_client):
    """Persist a per-title vector index and optionally mirror ``doc`` into Chroma."""
    index_path = f"{index_dir}/{title}"
    if not os.path.exists(index_path):
        vector_index = VectorStoreIndex.from_documents(
            documents=[doc],
            transformations=[splitter],
            callback_manager=callback_manager,
        )
        vector_index.storage_context.persist(persist_dir=index_path)
    if chroma_client is None:
        return
    # NOTE(review): unlike the on-disk index above, this step is not guarded by
    # an existence check, so re-running presumably adds the document to the
    # collection again — confirm whether that is intended.
    collection = chroma_client.get_or_create_collection(name=collection_name)
    vector_store = ChromaVectorStore(chroma_collection=collection)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    VectorStoreIndex.from_documents(
        documents=[doc],
        storage_context=storage_context,
        embed_model=embed_model,
        show_progress=True,
    )


def _write_summary_once(index_dir, title, summary):
    """Write ``summary`` to ``<index_dir>/summaries/<title>`` unless it already exists."""
    summaries_dir = Path(f"{index_dir}/summaries")
    out_path = summaries_dir / title
    if out_path.exists():
        return
    summaries_dir.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(summary)


def _open_chroma_client(chroma_path, titles):
    """Return one shared Chroma client, or ``None`` when mirroring is not needed."""
    if chroma_path is None or not titles:
        return None
    return chromadb.PersistentClient(path=chroma_path)


def _prepare_qa_nodes(file_dir, index_dir, chroma_path):
    """Index every Q&A file in ``file_dir``; return one IndexNode per non-empty file."""
    nodes = []
    titles = get_file_name(file_dir)
    chroma_client = _open_chroma_client(chroma_path, titles)
    file_count = 0
    for title in titles:
        question, answers_text = _parse_qa_file(f"{file_dir}/{title}.txt")
        if answers_text == "":
            # Nothing to index for this question.
            continue
        doc = Document(text=answers_text)
        _index_document(doc, title, index_dir, f"file_{file_count}", chroma_client)
        _write_summary_once(
            index_dir, title, f"This is some answers about {question}"
        )
        # The top-level node carries the question and points at the per-title index.
        nodes.append(IndexNode(text=question, index_id=title))
        file_count += 1
    return nodes


def _prepare_blog_nodes(file_dir, index_dir, chroma_path):
    """Index every blog file in ``file_dir``; return one IndexNode per file."""
    nodes = []
    titles = get_file_name(file_dir)
    # Load every document before indexing, so a load failure aborts before any
    # index is written.
    docs_dict = {
        title: SimpleDirectoryReader(
            input_files=[f"{file_dir}/{title}.txt"]
        ).load_data()[0]
        for title in titles
    }
    chroma_client = _open_chroma_client(chroma_path, titles)
    for file_count, title in enumerate(titles):
        _index_document(
            docs_dict[title], title, index_dir, f"file_{file_count}", chroma_client
        )
        # A fixed template sentence is used as the summary; an LLM-generated
        # summary (SummaryIndex + tree_summarize) was sketched here previously
        # but is not enabled.
        summary = f"This is a article about {title}"
        _write_summary_once(index_dir, title, summary)
        nodes.append(IndexNode(text=summary, index_id=title))
    return nodes


def prepare_nodes(file_dir, index_dir, data_type, chroma_path=None):
    """Build the per-file indexes for one data type and return their top-level nodes.

    Args:
        file_dir: Directory holding the ``.txt`` files of this data type
            (e.g. ``data/txt/<data_type>``). File names without extension are
            used as titles, and ``<title>.txt`` is what gets read.
        index_dir: Root directory of the persisted indexes (``kb``). Each
            title gets ``<index_dir>/<title>`` and a summary file under
            ``<index_dir>/summaries``.
        data_type: ``"qa"`` or ``"blog"``. Any other value indexes nothing.
        chroma_path: Directory of a persistent Chroma database into which each
            document is additionally indexed. ``None`` (the default) skips the
            Chroma step.

    Returns:
        A list of ``IndexNode`` objects whose ``index_id`` is the file title.
        For ``"qa"`` the node text is the question; for ``"blog"`` it is the
        summary sentence. Q&A files without any answer are skipped.

    NOTE(review): Chroma collections are named ``file_<n>`` with ``n``
    restarting at 0 for every call, so two data types sharing the same
    ``chroma_path`` end up in the same collections. The names are left
    unchanged here because readers of those collections are not visible in
    this file.
    """
    if data_type == "qa":
        return _prepare_qa_nodes(file_dir, index_dir, chroma_path)
    if data_type == "blog":
        return _prepare_blog_nodes(file_dir, index_dir, chroma_path)
    return []
def create_top_index(data_dir, index_dir):
    """Build (or extend) the top-level vector index and the BM25 retriever.

    Every sub-directory of ``data_dir`` is treated as a data type (blog, qa,
    ...) and handed to ``prepare_nodes``. The collected nodes are stored in a
    vector index persisted at ``<index_dir>/top_index`` and in a BM25
    retriever persisted at ``<index_dir>/top_index/bm25_retriever``.

    Args:
        data_dir: Directory whose sub-directories are the data types.
        index_dir: Root directory of the persisted indexes.
    """
    data_types = [
        entry
        for entry in os.listdir(data_dir)
        if os.path.isdir(f"{data_dir}/{entry}")
    ]
    all_nodes = []
    for data_type in data_types:
        # prepare_nodes needs a Chroma location; use the same
        # "<index_dir>/chroma" path that create_top_index_chroma uses.
        all_nodes.extend(
            prepare_nodes(
                f"{data_dir}/{data_type}",
                index_dir,
                data_type=data_type,
                chroma_path=f"{index_dir}/chroma",
            )
        )

    top_index_dir = f"{index_dir}/top_index"
    if not os.path.exists(top_index_dir):
        # First run: create and persist the index.
        top_vector_index = VectorStoreIndex(nodes=all_nodes)
        top_vector_index.storage_context.persist(persist_dir=top_index_dir)
    else:
        # Later runs: load, add only titles that are not indexed yet, and
        # persist again so the additions survive the process.
        top_vector_index = load_index_from_storage(
            storage_context=StorageContext.from_defaults(persist_dir=top_index_dir)
        )
        known_ids = {
            getattr(node, "index_id", None)
            for node in top_vector_index.docstore.docs.values()
        }
        new_nodes = [node for node in all_nodes if node.index_id not in known_ids]
        if new_nodes:
            top_vector_index.insert_nodes(nodes=new_nodes)
            top_vector_index.storage_context.persist(persist_dir=top_index_dir)

    # The BM25 retriever is rebuilt from the current nodes on every run.
    bm25_retriever = BM25Retriever.from_defaults(
        nodes=all_nodes,
        similarity_top_k=2,
        stemmer=Stemmer.Stemmer("english"),
        language="english",
    )
    bm25_retriever.persist(f"{top_index_dir}/bm25_retriever")
def create_top_index_chroma(data_dir, index_dir):
    """Build the Chroma-backed top-level index and the BM25 retriever.

    Every sub-directory of ``data_dir`` is treated as a data type (blog, qa,
    ...) and handed to ``prepare_nodes`` with ``<index_dir>/chroma`` as the
    Chroma location. The collected nodes go into the ``top_index`` collection
    of a Chroma database at ``<index_dir>/chroma/top_index``; the BM25
    retriever is persisted at ``<index_dir>/chroma/top_index/bm25_retriever``.

    Args:
        data_dir: Directory whose sub-directories are the data types.
        index_dir: Root directory of the persisted indexes.
    """
    chroma_path = f"{index_dir}/chroma"
    data_types = [
        entry
        for entry in os.listdir(data_dir)
        if os.path.isdir(f"{data_dir}/{entry}")
    ]
    all_nodes = []
    for data_type in data_types:
        all_nodes.extend(
            prepare_nodes(
                f"{data_dir}/{data_type}",
                index_dir,
                data_type=data_type,
                chroma_path=chroma_path,
            )
        )

    top_index_dir = f"{index_dir}/chroma/top_index"
    db = chromadb.PersistentClient(path=top_index_dir)
    chroma_collection = db.get_or_create_collection(name="top_index")
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    # Decide "new vs. existing" from the collection itself. A directory check
    # at this point is useless, because the client above has already been
    # opened on that path.
    if chroma_collection.count() == 0:
        VectorStoreIndex(
            nodes=all_nodes,
            storage_context=StorageContext.from_defaults(vector_store=vector_store),
        )
    # An already populated collection is left as it is.

    # The BM25 retriever is rebuilt from the current nodes on every run.
    bm25_retriever = BM25Retriever.from_defaults(
        nodes=all_nodes,
        similarity_top_k=2,
        stemmer=Stemmer.Stemmer("english"),
        language="english",
    )
    bm25_retriever.persist(f"{top_index_dir}/bm25_retriever")
def get_recursive_retriever(data_dir, index_dir):
    """Return a RecursiveRetriever rooted at the persisted top-level vector index.

    The root retriever (key ``"vector"``) is loaded from
    ``<index_dir>/top_index`` and asks for the 5 most similar nodes. For every
    file title found in the sub-directories of ``data_dir`` that has a
    persisted index at ``<index_dir>/<title>``, a retriever with
    ``similarity_top_k=3`` is registered under that title.

    Args:
        data_dir: Directory whose sub-directories hold the source text files.
        index_dir: Root directory of the persisted indexes.
    """
    top_vector_index = load_index_from_storage(
        storage_context=StorageContext.from_defaults(
            persist_dir=f"{index_dir}/top_index"
        )
    )
    type_dirs = [
        sub_dir
        for sub_dir in (f"{data_dir}/{entry}" for entry in os.listdir(data_dir))
        if os.path.isdir(sub_dir)
    ]
    vector_retrievers = {}
    for type_dir in type_dirs:
        for title in get_file_name(type_dir):
            persistent_dir = f"{index_dir}/{title}"
            if not os.path.exists(persistent_dir):
                # No index was built for this file (e.g. an empty Q&A file).
                continue
            vector_index = load_index_from_storage(
                storage_context=StorageContext.from_defaults(
                    persist_dir=persistent_dir
                )
            )
            vector_retrievers[title] = vector_index.as_retriever(similarity_top_k=3)
    return RecursiveRetriever(
        "vector",
        retriever_dict={
            # The keyword must be spelled "similarity_top_k"; a misspelled
            # keyword is not applied and the retriever falls back to its default.
            "vector": top_vector_index.as_retriever(similarity_top_k=5),
            **vector_retrievers,
        },
        verbose=True,
    )
def get_bm25_recursive_retriever(data_dir, index_dir):
    """Return a RecursiveRetriever whose root is the persisted BM25 retriever.

    The root (key ``"bm25"``) is loaded from
    ``<index_dir>/top_index/bm25_retriever``. Each file title found in the
    sub-directories of ``data_dir`` that has a persisted index at
    ``<index_dir>/<title>`` is registered as a vector retriever with
    ``similarity_top_k=3``.
    """
    root_retriever = BM25Retriever.from_persist_dir(
        f"{index_dir}/top_index/bm25_retriever"
    )

    # Sub-directories of data_dir, one per data type.
    category_dirs = [
        candidate
        for candidate in (f"{data_dir}/{entry}" for entry in os.listdir(data_dir))
        if os.path.isdir(candidate)
    ]

    leaf_retrievers = {}
    for category_dir in category_dirs:
        for title in get_file_name(category_dir):
            leaf_dir = f"{index_dir}/{title}"
            if not os.path.exists(leaf_dir):
                continue
            storage = StorageContext.from_defaults(persist_dir=leaf_dir)
            leaf_index = load_index_from_storage(storage_context=storage)
            leaf_retrievers[title] = leaf_index.as_retriever(similarity_top_k=3)

    return RecursiveRetriever(
        "bm25",
        retriever_dict={"bm25": root_retriever, **leaf_retrievers},
        verbose=True,
    )
def get_hybrid_recursive_retriever(data_dir, index_dir):
    """Return a RecursiveRetriever whose root fuses BM25 and vector retrieval.

    The root (key ``"hybrid"``) is a ``QueryFusionRetriever`` in
    ``"reciprocal_rerank"`` mode over the persisted BM25 retriever and the
    persisted top-level vector index (both read from
    ``<index_dir>/top_index``), each with a top-k of 2. Each file title found
    in the sub-directories of ``data_dir`` that has a persisted index at
    ``<index_dir>/<title>`` is registered as a vector retriever with
    ``similarity_top_k=1``.
    """
    top_dir = f"{index_dir}/top_index"
    sparse_retriever = BM25Retriever.from_persist_dir(f"{top_dir}/bm25_retriever")
    top_storage = StorageContext.from_defaults(persist_dir=top_dir)
    dense_retriever = load_index_from_storage(
        storage_context=top_storage
    ).as_retriever(similarity_top_k=2)
    fusion_retriever = QueryFusionRetriever(
        retrievers=[sparse_retriever, dense_retriever],
        similarity_top_k=2,
        num_queries=1,
        mode="reciprocal_rerank",
        use_async=False,
        verbose=True,
    )

    # Sub-directories of data_dir, one per data type.
    category_dirs = []
    for entry in os.listdir(data_dir):
        candidate = f"{data_dir}/{entry}"
        if os.path.isdir(candidate):
            category_dirs.append(candidate)

    leaf_retrievers = {}
    for category_dir in category_dirs:
        for title in get_file_name(category_dir):
            leaf_dir = f"{index_dir}/{title}"
            if os.path.exists(leaf_dir):
                leaf_storage = StorageContext.from_defaults(persist_dir=leaf_dir)
                leaf_retrievers[title] = load_index_from_storage(
                    storage_context=leaf_storage
                ).as_retriever(similarity_top_k=1)

    return RecursiveRetriever(
        "hybrid",
        retriever_dict={"hybrid": fusion_retriever, **leaf_retrievers},
        verbose=True,
    )
if __name__ == "__main__":
    # Manual entry point. Only the last statement is live; the commented-out
    # snippets above it are earlier experiments kept for reference.
    #
    # -- Build the non-Chroma top index for the full data set:
    # create_top_index(data_dir="/data1/home/purui/projects/chatbot/data/txt", index_dir="/data1/home/purui/projects/chatbot/kb")
    #
    # -- Query the persisted test top index directly:
    # top_index = load_index_from_storage(storage_context=StorageContext.from_defaults(persist_dir="/data1/home/purui/projects/chatbot/tests/kb/top_index"))
    # retriever = top_index.as_retriever(similarity_top_k=2)
    # nodes = retriever.retrieve("My girlfriend dont want sex. What should I do?")
    # print(nodes)
    #
    # -- Vector-rooted recursive retrieval on the test set:
    # recursive_retriever = get_recursive_retriever(data_dir="/data1/home/purui/projects/chatbot/tests/data/txt", index_dir="/data1/home/purui/projects/chatbot/tests/kb")
    # nodes = recursive_retriever.retrieve("what stages will I experience during the orgasm?")
    # print(nodes)
    #
    # -- BM25-rooted recursive retrieval on the full set:
    # bm25_recursive_retriever = get_bm25_recursive_retriever(data_dir="/data1/home/purui/projects/chatbot/data/txt", index_dir="/data1/home/purui/projects/chatbot/kb")
    # bm25_recursive_retriever.retrieve("How to give a good blowjob?")
    #
    # -- Hybrid (BM25 + vector) recursive retrieval on the full set:
    # import nest_asyncio
    # nest_asyncio.apply()
    # hybrid_recursive_retriever = get_hybrid_recursive_retriever(data_dir="/data1/home/purui/projects/chatbot/data/txt", index_dir="/data1/home/purui/projects/chatbot/kb")
    # hybrid_recursive_retriever.retrieve("How to give a good blowjob?")
    #
    # -- Inspect the nodes of a single per-title index:
    # index = load_index_from_storage(storage_context=StorageContext.from_defaults(persist_dir="/data1/home/purui/projects/chatbot/kb/Intercourse feels strange"))
    # nodes = index._get_node_with_embedding()
    # print(nodes)
    #
    # Live call: build the Chroma-backed top index for the test data set.
    create_top_index_chroma(data_dir="/data1/home/purui/projects/chatbot/tests/data/txt", index_dir="/data1/home/purui/projects/chatbot/tests/kb")