File size: 4,233 Bytes
318db6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from pymilvus import MilvusClient, WeightedRanker, AnnSearchRequest
from langchain_ollama import OllamaEmbeddings

class MilvusRetriever:
    """Embed text queries with an Ollama model and search Milvus collections.

    Each supported collection is searched by its own method, which knows the
    vector field(s), scalar filter and output fields of that collection.
    """

    # Supported collection name -> name of the method that searches it.
    _SEARCH_METHODS = {
        "t_sur_sex_ed_article_spider": "article_search",
        "t_sur_sex_ed_question_answer_spider": "qa_search",
        "t_sur_sex_ed_youtube_spider": "video_search",
    }

    # (vector field, WeightedRanker weight) pairs for the article hybrid search.
    # Weights are passed to WeightedRanker in the same order as the requests.
    _ARTICLE_FIELD_WEIGHTS = (
        ("chunk_vector", 0.6),
        ("title_vector", 0.3),
        ("tags", 0.1),
    )

    def __init__(self, uri, embedding_model="bge-m3"):
        """Create the query-embedding model and the Milvus client.

        Args:
            uri (str): Milvus server URI, e.g. "http://localhost:19530".
            embedding_model (str, optional): Ollama model used to embed
                queries. Defaults to "bge-m3".
        """
        self.uri = uri
        self.embed_model = OllamaEmbeddings(model=embedding_model)
        self.client = MilvusClient(self.uri)

    @staticmethod
    def _search_params():
        """Return a fresh ANN search-parameter dict (COSINE metric, nprobe=10)."""
        # Built per call so no two requests share one (possibly mutated) dict.
        return {"metric_type": "COSINE", "params": {"nprobe": 10}}

    @staticmethod
    def _dedupe_by_title(records):
        """Keep the first record for each distinct ``entity["title"]``, in order."""
        seen = set()
        unique = []
        for record in records:
            title = record["entity"]["title"]
            if title not in seen:
                seen.add(title)
                unique.append(record)
        return unique

    def search(self, query, collection_name, top_k=10):
        """Embed ``query`` and search ``collection_name`` for its nearest neighbours.

        Args:
            query (str): Query string.
            collection_name (str): Milvus collection to search; must be one of
                the keys of ``_SEARCH_METHODS``.
            top_k (int, optional): Maximum number of results. Defaults to 10.

        Returns:
            list: Hits shaped like ``{"id", "distance", "entity"}``.

        Raises:
            ValueError: If ``collection_name`` is not a supported collection.
        """
        method_name = self._SEARCH_METHODS.get(collection_name)
        if method_name is None:
            # Validate before embedding so an unknown collection fails fast with
            # a clear error instead of silently returning None.
            raise ValueError(
                f"Unsupported collection: {collection_name!r}; expected one of "
                f"{sorted(self._SEARCH_METHODS)}"
            )
        query_embedding = self.embed_model.embed_query(query)
        return getattr(self, method_name)(query_embedding, collection_name, top_k=top_k)

    def article_search(self, embedding, collection_name, top_k):
        """Hybrid-search articles over their chunk, title and tag vectors.

        Issues one ANN request per field in ``_ARTICLE_FIELD_WEIGHTS`` and fuses
        them with a ``WeightedRanker``. Each hit's ``distance`` is therefore the
        ranker's fused score rather than a raw cosine similarity.
        NOTE(review): assumes ``tags`` is a vector field embedded with the same
        model as the query — confirm against the collection schema.

        Args:
            embedding (list): Query embedding.
            collection_name (str): Article collection name.
            top_k (int): Candidates per request and number of fused results.

        Returns:
            list: Hits with ``title``, ``link``, ``chunk`` and ``category``.
        """
        requests = [
            AnnSearchRequest(
                data=[embedding],
                anns_field=field,
                param=self._search_params(),
                limit=top_k,
            )
            for field, _ in self._ARTICLE_FIELD_WEIGHTS
        ]
        ranker = WeightedRanker(*(weight for _, weight in self._ARTICLE_FIELD_WEIGHTS))
        return self.client.hybrid_search(
            collection_name=collection_name,
            ranker=ranker,
            reqs=requests,
            limit=top_k,
            output_fields=["title", "link", "chunk", "category"],
        )[0]

    def qa_search(self, embedding, collection_name, top_k):
        """Search Q&A records by title vector, returning one hit per title.

        Only rows with ``content_type == 'A'`` are searched (presumably answers —
        confirm against the ingestion code).

        Args:
            embedding (list): Query embedding.
            collection_name (str): Q&A collection name.
            top_k (int): Maximum number of hits requested from Milvus.

        Returns:
            list: Hits deduplicated by title, in Milvus result order.
        """
        res = self.client.search(
            collection_name=collection_name,
            data=[embedding],
            anns_field="title_vector",
            search_params=self._search_params(),
            limit=top_k,
            filter="content_type == 'A'",
            output_fields=["title", "content", "url", "author", "avatar_url", "likes", "dislikes"],
        )[0]
        # Several answers can share a question title; keep the first (top-ranked)
        # one per title. NOTE: dedup runs after the limit, so fewer than top_k
        # hits may be returned.
        return self._dedupe_by_title(res)

    def video_search(self, embedding, collection_name, top_k):
        """Search videos by title vector.

        Only rows with ``delete_status == 0`` are searched (presumably videos
        not marked deleted — confirm against the ingestion code).

        Args:
            embedding (list): Query embedding.
            collection_name (str): Video collection name.
            top_k (int): Maximum number of hits.

        Returns:
            list: Hits with ``title``, ``link``, ``author``, ``picture``, ``duration``.
        """
        return self.client.search(
            collection_name=collection_name,
            data=[embedding],
            anns_field="title_vector",
            search_params=self._search_params(),
            filter="delete_status == 0",
            limit=top_k,
            output_fields=["title", "link", "author", "picture", "duration"],
        )[0]

    def porn_search(self, embedding, collection_name, top_k):
        """Placeholder: not implemented and not routed to by ``search``.

        Returns:
            None
        """
        # TODO: implement, then register the collection in _SEARCH_METHODS.
        return None

if __name__ == "__main__":
    import json

    # Manual smoke test: query the article collection on a local Milvus and
    # print the entities of sufficiently close hits as JSON.
    retriever = MilvusRetriever(uri="http://localhost:19530")
    collection_name = "t_sur_sex_ed_article_spider"
    query = "How to build trust?"
    # Hits scoring at or below this are dropped. With the COSINE metric Milvus
    # reports a similarity (higher = closer); the article search returns a
    # fused WeightedRanker score, so this is a heuristic cut-off.
    min_score = 0.3
    res = retriever.search(query, collection_name, top_k=10)
    entities = [record["entity"] for record in res if record["distance"] > min_score]
    # ensure_ascii=False keeps non-ASCII text (e.g. Chinese content) readable.
    print(json.dumps(entities, ensure_ascii=False, indent=2))