# SexBot / tool_retrieval.py
# Pew404's picture
# Upload folder using huggingface_hub
# 13fbd2e verified
# Names of the available indexes; each index name is treated as a tool.
tools_name = [
    "alice",
    "ReAct",
    "BERT",
    "sexuality",
    "sexual security",
    "medical",
]
# Embed the tool names into the top-level index (`top_index`).
from llama_index.llms.ollama import Ollama
from langchain_community.embeddings.ollama import OllamaEmbeddings
from llama_index.core.schema import TextNode
from llama_index.core import VectorStoreIndex
from llama_index.retrievers.bm25 import BM25Retriever
import Stemmer
from llama_index.core import Settings
# Create the Ollama chat model and the embedding model, and register each one
# on llama_index's Settings object right after it is built.
llm = Ollama(model="llama3")
Settings.llm = llm

# NOTE(review): OllamaEmbeddings is imported from langchain_community, not
# llama_index — confirm the LangChain embeddings adapter is installed.
embed_model = OllamaEmbeddings(model="nomic-embed-text")
Settings.embed_model = embed_model
# One node per tool name. TextNode's identifier field is `id_` (a plain `id=`
# keyword is not a TextNode field), so pass the tool name there to give each
# node a readable id equal to its tool name.
nodes = [TextNode(text=tool, id_=tool) for tool in tools_name]

# Vector search: build a vector index over the tool-name nodes.
index = VectorStoreIndex(nodes=nodes, show_progress=True)

# BM25 retrieval over the same nodes, using an English stemmer and returning
# the two best matches per query.
bm25_retriever = BM25Retriever.from_defaults(
    nodes=nodes,
    similarity_top_k=2,
    stemmer=Stemmer.Stemmer("english"),
    language="english",
)
# Match the query against the tool names and keep the top two hits from each
# retriever (vector similarity and BM25), then print both lists side by side.
tool_retriever = index.as_retriever(similarity_top_k=2)
query = "alice is a good girl or bad girl."
retrieved_tools = [
    hit.text for hit in tool_retriever.retrieve(query)
]
bm25_retrieved_tools = [
    hit.text for hit in bm25_retriever.retrieve(query)
]
print(f"vector_retrieved: {retrieved_tools}")
print(f"bm25_retrieved: {bm25_retrieved_tools}")
# Look up the tool by its name.