SexBot / milvusDB /retriever.py
Pew404's picture
Upload folder using huggingface_hub
318db6e verified
from pymilvus import MilvusClient, WeightedRanker, AnnSearchRequest
from langchain_ollama import OllamaEmbeddings
class MilvusRetriever:
    """Embed a text query and retrieve matching records from Milvus.

    The retriever owns an Ollama embedding model ("bge-m3") and a
    ``MilvusClient`` connection. ``search`` is the entry point: it embeds the
    query once and delegates to the collection-specific search method.
    """

    # Supported collections mapped to the name of the method that queries them.
    # Method names (not bound methods) are stored so subclasses can override
    # the individual search methods.
    _HANDLERS = {
        "t_sur_sex_ed_article_spider": "article_search",
        "t_sur_sex_ed_question_answer_spider": "qa_search",
        "t_sur_sex_ed_youtube_spider": "video_search",
    }

    def __init__(self, uri):
        """Create the embedding model and connect to Milvus.

        Args:
            uri: Address passed to ``MilvusClient`` (the ``__main__`` example
                in this file uses ``"http://localhost:19530"``).
        """
        self.uri = uri
        self.embed_model = OllamaEmbeddings(model="bge-m3")
        self.client = MilvusClient(self.uri)

    @staticmethod
    def _cosine_params():
        """Return a fresh copy of the ANN search parameters used everywhere."""
        # A new dict per call so no request shares mutable state with another.
        return {"metric_type": "COSINE", "params": {"nprobe": 10}}

    def search(self, query, collection_name, top_k=10):
        """Embed ``query`` and search the given collection.

        Args:
            query: Query string to embed.
            collection_name: Name of the Milvus collection. Only the
                collections listed in ``_HANDLERS`` are supported.
            top_k: Maximum number of hits requested from Milvus.
                Defaults to 10.

        Returns:
            A list of hits shaped like ``{"id", "distance", "entity"}``, or
            ``None`` when ``collection_name`` is not a supported collection.
        """
        handler_name = self._HANDLERS.get(collection_name)
        if handler_name is None:
            # Unknown collection: keep the historical ``None`` result, but do
            # not spend an embedding call on a query that cannot be served.
            return None

        query_embedding = self.embed_model.embed_query(query)
        handler = getattr(self, handler_name)
        return handler(query_embedding, collection_name, top_k=top_k)

    def article_search(self, embedding, collection_name, top_k):
        """Run a weighted hybrid search over the article collection.

        Three ANN requests are issued with the same embedding against the
        ``chunk_vector``, ``title_vector`` and ``tags`` fields, and their
        scores are combined by a ``WeightedRanker`` with weights
        0.6 / 0.3 / 0.1 respectively.

        Args:
            embedding: Query embedding vector.
            collection_name: Milvus collection to search.
            top_k: Limit for each individual request and for the merged result.

        Returns:
            The hits for the single query vector, each carrying the
            ``title``, ``link``, ``chunk`` and ``category`` output fields.
        """
        # NOTE(review): "tags" is searched as a vector field here; confirm the
        # collection schema really stores an embedding under that name.
        weighted_fields = (
            ("chunk_vector", 0.6),
            ("title_vector", 0.3),
            ("tags", 0.1),
        )
        requests = [
            AnnSearchRequest(
                data=[embedding],
                anns_field=field,
                param=self._cosine_params(),
                limit=top_k,
            )
            for field, _ in weighted_fields
        ]
        # Weights are positional and must follow the order of ``requests``.
        ranker = WeightedRanker(*(weight for _, weight in weighted_fields))

        # Only one query vector is sent, so take the first result set.
        return self.client.hybrid_search(
            collection_name=collection_name,
            ranker=ranker,
            reqs=requests,
            limit=top_k,
            output_fields=["title", "link", "chunk", "category"],
        )[0]

    def qa_search(self, embedding, collection_name, top_k):
        """Search the question/answer collection by title vector.

        Only records matching ``content_type == 'A'`` are considered. Hits
        that repeat an already-seen title are dropped, keeping the first
        occurrence; because this happens after Milvus applied ``limit``, fewer
        than ``top_k`` hits may be returned.

        Args:
            embedding: Query embedding vector.
            collection_name: Milvus collection to search.
            top_k: Maximum number of hits requested from Milvus.

        Returns:
            The hits with unique titles, in the order Milvus returned them.
        """
        hits = self.client.search(
            collection_name=collection_name,
            data=[embedding],
            anns_field="title_vector",
            search_params=self._cosine_params(),
            limit=top_k,
            filter="content_type == 'A'",
            output_fields=["title", "content", "url", "author", "avatar_url", "likes", "dislikes"],
        )[0]

        # De-duplicate by title; a set keeps each membership test O(1).
        seen_titles = set()
        unique_hits = []
        for hit in hits:
            title = hit["entity"]["title"]
            if title in seen_titles:
                continue
            seen_titles.add(title)
            unique_hits.append(hit)
        return unique_hits

    def video_search(self, embedding, collection_name, top_k):
        """Search the video collection by title vector.

        Only records matching ``delete_status == 0`` are considered.

        Args:
            embedding: Query embedding vector.
            collection_name: Milvus collection to search.
            top_k: Maximum number of hits requested from Milvus.

        Returns:
            The hits for the single query vector, each carrying the ``title``,
            ``link``, ``author``, ``picture`` and ``duration`` output fields.
        """
        return self.client.search(
            collection_name=collection_name,
            data=[embedding],
            anns_field="title_vector",
            search_params=self._cosine_params(),
            filter="delete_status == 0",
            limit=top_k,
            output_fields=["title", "link", "author", "picture", "duration"],
        )[0]

    def porn_search(self, embedding, collection_name, top_k):
        """Placeholder: not implemented and not routed from ``search``.

        Returns:
            Always ``None``.
        """
        return None
if __name__ == "__main__":
    import json

    # Manual smoke test: query the article collection on a local Milvus
    # instance and print the entities of sufficiently similar hits as JSON.
    milvus = MilvusRetriever(uri="http://localhost:19530")
    hits = milvus.search(
        "How to build trust?",
        "t_sur_sex_ed_article_spider",
        top_k=10,
    )

    # Keep only hits whose reported distance is above 0.3.
    entities = []
    for hit in hits:
        if hit["distance"] > 0.3:
            entities.append(hit["entity"])

    print(json.dumps(entities))