from pymilvus import MilvusClient, WeightedRanker, AnnSearchRequest
from langchain_ollama import OllamaEmbeddings

# Collection names understood by MilvusRetriever.search.
ARTICLE_COLLECTION = "t_sur_sex_ed_article_spider"
QA_COLLECTION = "t_sur_sex_ed_question_answer_spider"
VIDEO_COLLECTION = "t_sur_sex_ed_youtube_spider"

# (vector field, ranker weight) pairs for the article hybrid search. Keeping
# field and weight together guarantees WeightedRanker receives exactly one
# weight per AnnSearchRequest, in the same order.
# NOTE(review): "tags" is queried as a vector field; confirm the collection
# schema defines it as one (hybrid_search requires vector fields).
_ARTICLE_VECTOR_WEIGHTS = (
    ("chunk_vector", 0.6),
    ("title_vector", 0.3),
    ("tags", 0.1),
)


def _search_params():
    """Return a fresh COSINE search-parameter dict (``nprobe`` = 10).

    A new dict is built on every call so no mutable object is shared between
    requests.
    """
    return {"metric_type": "COSINE", "params": {"nprobe": 10}}


class MilvusRetriever:
    """Embed text queries with Ollama ``bge-m3`` and search Milvus collections."""

    def __init__(self, uri: str):
        """Create the embedding model and the Milvus client.

        Args:
            uri: Milvus server URI, e.g. ``"http://localhost:19530"``.
        """
        self.uri = uri
        self.embed_model = OllamaEmbeddings(model="bge-m3")
        self.client = MilvusClient(self.uri)

    def search(self, query: str, collection_name: str, top_k: int = 10):
        """Embed ``query`` and return its nearest records in ``collection_name``.

        The collection name selects the search strategy:

        * articles: weighted hybrid search over several vector fields;
        * Q&A: title-vector search, deduplicated by title;
        * videos: title-vector search over rows with ``delete_status == 0``.

        Args:
            query: Free-text query string.
            collection_name: One of ``ARTICLE_COLLECTION``, ``QA_COLLECTION``
                or ``VIDEO_COLLECTION``.
            top_k: Maximum number of hits requested from Milvus. Defaults to 10.

        Returns:
            The hits for the query, each a mapping with ``"id"``,
            ``"distance"`` and ``"entity"`` keys.

        Raises:
            ValueError: If ``collection_name`` is not a supported collection.
        """
        handlers = {
            ARTICLE_COLLECTION: self.article_search,
            QA_COLLECTION: self.qa_search,
            VIDEO_COLLECTION: self.video_search,
        }
        handler = handlers.get(collection_name)
        if handler is None:
            # Fail before spending an embedding call. Previously this fell
            # through and returned None.
            raise ValueError(
                f"Unsupported collection {collection_name!r}; "
                f"expected one of {sorted(handlers)}"
            )
        query_embedding = self.embed_model.embed_query(query)
        return handler(query_embedding, collection_name, top_k=top_k)

    def article_search(self, embedding, collection_name, top_k):
        """Hybrid search over the article collection.

        Issues one ANN request per field in ``_ARTICLE_VECTOR_WEIGHTS`` and
        fuses the results with a ``WeightedRanker`` using the paired weights.

        Args:
            embedding: Query embedding vector.
            collection_name: Milvus collection to search.
            top_k: Limit applied to each sub-request and to the fused result.

        Returns:
            Fused hits carrying the ``title``, ``link``, ``chunk`` and
            ``category`` output fields.
        """
        requests = [
            AnnSearchRequest(
                data=[embedding],
                anns_field=field,
                param=_search_params(),
                limit=top_k,
            )
            for field, _ in _ARTICLE_VECTOR_WEIGHTS
        ]
        ranker = WeightedRanker(*(weight for _, weight in _ARTICLE_VECTOR_WEIGHTS))
        results = self.client.hybrid_search(
            collection_name=collection_name,
            reqs=requests,
            ranker=ranker,
            limit=top_k,
            output_fields=["title", "link", "chunk", "category"],
        )
        # One query vector was sent, so only the first result list is relevant.
        return results[0]

    def qa_search(self, embedding, collection_name, top_k):
        """Title-vector search over Q&A rows, keeping one hit per title.

        Only rows with ``content_type == 'A'`` are searched (presumably
        answers; confirm against the collection schema). Deduplication runs
        after the ``top_k`` limit, so fewer than ``top_k`` hits may be returned.

        Args:
            embedding: Query embedding vector.
            collection_name: Milvus collection to search.
            top_k: Maximum number of hits requested from Milvus.

        Returns:
            Hits in Milvus rank order, keeping the first hit for each title.
        """
        hits = self.client.search(
            collection_name=collection_name,
            data=[embedding],
            anns_field="title_vector",
            search_params=_search_params(),
            limit=top_k,
            filter="content_type == 'A'",
            output_fields=["title", "content", "url", "author", "avatar_url", "likes", "dislikes"],
        )[0]
        return self._dedupe_by_title(hits)

    @staticmethod
    def _dedupe_by_title(hits):
        """Return ``hits`` keeping only the first hit for each entity title."""
        seen_titles = set()
        unique = []
        for hit in hits:
            title = hit["entity"]["title"]
            if title not in seen_titles:
                seen_titles.add(title)
                unique.append(hit)
        return unique

    def video_search(self, embedding, collection_name, top_k):
        """Title-vector search over video rows whose ``delete_status`` is 0.

        Args:
            embedding: Query embedding vector.
            collection_name: Milvus collection to search.
            top_k: Maximum number of hits requested from Milvus.

        Returns:
            Hits carrying the ``title``, ``link``, ``author``, ``picture`` and
            ``duration`` output fields.
        """
        return self.client.search(
            collection_name=collection_name,
            data=[embedding],
            anns_field="title_vector",
            search_params=_search_params(),
            filter="delete_status == 0",
            limit=top_k,
            output_fields=["title", "link", "author", "picture", "duration"],
        )[0]

    def porn_search(self, embedding, collection_name, top_k):
        """Placeholder: not implemented and not dispatched by ``search``.

        Always returns ``None``.
        """
        # TODO: implement once the target collection/schema is defined.
        return None


if __name__ == "__main__":
    import json

    retriever = MilvusRetriever(uri="http://localhost:19530")
    collection_name = ARTICLE_COLLECTION
    query = "How to build trust?"
    hits = retriever.search(query, collection_name, top_k=10)
    # Keep only hits scoring above 0.3 (heuristic cut-off; confirm it suits the
    # fused hybrid scores).
    entities = [hit["entity"] for hit in hits if hit["distance"] > 0.3]
    print(json.dumps(entities))