|
from pymilvus import MilvusClient, WeightedRanker, AnnSearchRequest |
|
from langchain_ollama import OllamaEmbeddings |
|
|
|
class MilvusRetriever:
    """Semantic retriever over a fixed set of Milvus collections.

    A query string is embedded with the Ollama ``bge-m3`` model and then sent
    to the search strategy registered for the requested collection.
    """

    ARTICLE_COLLECTION = "t_sur_sex_ed_article_spider"
    QA_COLLECTION = "t_sur_sex_ed_question_answer_spider"
    VIDEO_COLLECTION = "t_sur_sex_ed_youtube_spider"

    # Collection name -> name of the method that knows how to search it.
    # Method *names* (not function objects) are stored so that subclasses
    # overriding a search method are still honoured by ``search``.
    _SEARCHERS = {
        ARTICLE_COLLECTION: "article_search",
        QA_COLLECTION: "qa_search",
        VIDEO_COLLECTION: "video_search",
    }

    def __init__(self, uri):
        """Create the embedding model and the Milvus client.

        Args:
            uri: Address handed to ``MilvusClient``.
        """
        self.uri = uri
        self.embed_model = OllamaEmbeddings(model="bge-m3")
        self.client = MilvusClient(self.uri)

    @staticmethod
    def _search_params():
        """Return a fresh copy of the ANN parameters used by every search."""
        # A new dict per call, so no request can observe another's mutations.
        return {"metric_type": "COSINE", "params": {"nprobe": 10}}

    def _ann_request(self, embedding, anns_field, top_k):
        """Build one ``AnnSearchRequest`` for *anns_field* (hybrid search)."""
        return AnnSearchRequest(
            data=[embedding],
            anns_field=anns_field,
            param=self._search_params(),
            limit=top_k,
        )

    def search(self, query, collection_name, top_k=10):
        """Embed *query* and search *collection_name*.

        Args:
            query: Query string to embed.
            collection_name: One of the supported Milvus collection names
                (``ARTICLE_COLLECTION``, ``QA_COLLECTION``,
                ``VIDEO_COLLECTION``).
            top_k: Maximum number of hits requested from Milvus.
                Defaults to 10.

        Returns:
            The hits produced by the collection-specific search, each one
            exposing ``"id"``, ``"distance"`` and ``"entity"``.

        Raises:
            ValueError: If *collection_name* has no registered search
                strategy (previously this silently returned ``None``).
        """
        # Validate before embedding: the embedding call is the expensive part
        # and is pointless when the collection is not supported.
        handler_name = self._SEARCHERS.get(collection_name)
        if handler_name is None:
            supported = ", ".join(sorted(self._SEARCHERS))
            raise ValueError(
                f"Unsupported collection {collection_name!r}; "
                f"expected one of: {supported}"
            )

        query_embedding = self.embed_model.embed_query(query)
        handler = getattr(self, handler_name)
        return handler(query_embedding, collection_name, top_k=top_k)

    def article_search(self, embedding, collection_name, top_k):
        """Hybrid search over the article collection.

        Three ANN requests are issued with the same embedding, one per
        field, and merged with a ``WeightedRanker``.

        Args:
            embedding: Query embedding vector.
            collection_name: Milvus collection to search.
            top_k: Limit for each request and for the merged result.

        Returns:
            The hits for the (single) query vector, carrying the ``title``,
            ``link``, ``chunk`` and ``category`` output fields.
        """
        # (field, weight) pairs; the order of the weights passed to the
        # ranker must match the order of the requests.
        weighted_fields = (
            ("chunk_vector", 0.6),
            ("title_vector", 0.3),
            ("tags", 0.1),
        )
        requests = [
            self._ann_request(embedding, field, top_k)
            for field, _ in weighted_fields
        ]
        ranker = WeightedRanker(*(weight for _, weight in weighted_fields))

        results = self.client.hybrid_search(
            collection_name=collection_name,
            ranker=ranker,
            reqs=requests,
            limit=top_k,
            output_fields=["title", "link", "chunk", "category"],
        )
        # Only one query vector was submitted, so take its result list.
        return results[0]

    def qa_search(self, embedding, collection_name, top_k):
        """Search the Q&A collection and keep one hit per distinct title.

        Args:
            embedding: Query embedding vector.
            collection_name: Milvus collection to search.
            top_k: Maximum number of hits requested from Milvus. Because
                duplicates are dropped afterwards, fewer may be returned.

        Returns:
            A list of hits in the order Milvus returned them, keeping only
            the first hit seen for each ``title``.
        """
        hits = self.client.search(
            collection_name=collection_name,
            data=[embedding],
            anns_field="title_vector",
            search_params=self._search_params(),
            limit=top_k,
            # NOTE(review): 'A' presumably marks answer rows -- confirm
            # against the collection schema.
            filter="content_type == 'A'",
            output_fields=[
                "title", "content", "url", "author",
                "avatar_url", "likes", "dislikes",
            ],
        )[0]

        # A set gives O(1) membership checks; the list preserves hit order.
        seen_titles = set()
        unique_hits = []
        for hit in hits:
            title = hit["entity"]["title"]
            if title not in seen_titles:
                seen_titles.add(title)
                unique_hits.append(hit)
        return unique_hits

    def video_search(self, embedding, collection_name, top_k):
        """Search the video collection by title vector.

        Args:
            embedding: Query embedding vector.
            collection_name: Milvus collection to search.
            top_k: Maximum number of hits requested from Milvus.

        Returns:
            The hits for the (single) query vector, carrying the ``title``,
            ``link``, ``author``, ``picture`` and ``duration`` output fields.
        """
        return self.client.search(
            collection_name=collection_name,
            data=[embedding],
            anns_field="title_vector",
            search_params=self._search_params(),
            # NOTE(review): presumably excludes soft-deleted rows -- confirm.
            filter="delete_status == 0",
            limit=top_k,
            output_fields=["title", "link", "author", "picture", "duration"],
        )[0]

    def porn_search(self, embedding, collection_name, top_k):
        """Placeholder; not implemented yet and always returns ``None``."""
        pass
|
|
|
if __name__ == "__main__":
    import json

    # Manual smoke test against a locally running Milvus instance.
    article_collection = "t_sur_sex_ed_article_spider"
    question = "How to build trust?"

    milvus = MilvusRetriever(uri="http://localhost:19530")
    hits = milvus.search(question, article_collection, top_k=10)

    # Keep only the stored fields of hits whose distance value exceeds 0.3.
    relevant = []
    for hit in hits:
        if hit["distance"] > 0.3:
            relevant.append(hit["entity"])

    print(json.dumps(relevant))
|
|