File size: 3,188 Bytes
a71293a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from modal import Image, App, Secret, web_endpoint, Volume, enter, method, build
from typing import Dict
import sys

# Container image shared by every function/class in this app: Debian slim with
# Python 3.12 plus the vector store (chromadb), the embedding/cross-encoder
# models (sentence-transformers) and a standalone SQLite build (pysqlite3).
model_image = (Image.debian_slim(python_version="3.12")
              .pip_install("chromadb", "sentence-transformers", "pysqlite3-binary")
)

# Utilities
# Imports that are only guaranteed to exist inside `model_image`; they are run
# through Modal's `Image.imports()` context so this module can still be loaded
# where those packages are absent (NOTE(review): confirm exact suppression
# semantics against the installed Modal version).
with model_image.imports():
    import os
    import numpy as np  # used by RANKING.rank
    # Replace the stdlib `sqlite3` module with `pysqlite3` in sys.modules so that
    # chromadb (imported lazily inside VECTORDB) picks up pysqlite3's SQLite
    # build -- presumably because chromadb needs a newer SQLite than the system
    # one. This must happen before anything imports `sqlite3`.
    __import__("pysqlite3")
    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") # Hotswap SQLite version

# Application initialization
app = App("mps-api",
          image=model_image)
# Existing Modal Volume "mps" (create_if_missing=False: it must already exist);
# VECTORDB mounts it at `data_path` and reads the Chroma DB from data_path + "/db".
vol = Volume.from_name("mps", create_if_missing=False)
data_path = "/data"

############
# MAIN CLASS
############
# Identifiers used by VECTORDB: the embedding model and the Chroma database
# stored on the mounted volume.
_ENCODER_MODEL_NAME = "Lajavaness/sentence-camembert-large"
_DB_PATH = data_path + "/db"
_COLLECTION_NAME = "MPS"


@app.cls(timeout=30*60,
         volumes={data_path: vol})
class VECTORDB:
    """Semantic search over the "MPS" Chroma collection stored on the mounted volume."""

    @build()
    def download_encoder(self):
        """Download the encoder weights during image build so containers start warm.

        Only the model is touched here. The database is opened in `init` instead:
        a build step keeps only the resulting filesystem, not in-memory handles,
        so opening the collection at build time did no useful work and made the
        image build fail whenever the collection was not readable at that moment.
        """
        from sentence_transformers import SentenceTransformer
        SentenceTransformer(_ENCODER_MODEL_NAME)

    @enter()
    def init(self):
        """Load the encoder and open the persisted Chroma collection.

        Runs once per container start. Sets `self.embedding_function` and
        `self.chroma_collection`. `get_collection` fails if the collection does
        not exist on the volume; nothing is created here.
        """
        # Imported lazily: chromadb must be imported after the module-level
        # sqlite3 hot-swap has run.
        from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
        import chromadb

        self.embedding_function = SentenceTransformerEmbeddingFunction(model_name=_ENCODER_MODEL_NAME)
        print(f"Embedding model loaded: {_ENCODER_MODEL_NAME}")

        chroma_client = chromadb.PersistentClient(path=_DB_PATH)
        self.chroma_collection = chroma_client.get_collection(
            name=_COLLECTION_NAME, embedding_function=self.embedding_function)
        print(f"{self.chroma_collection.count()} documents loaded.")

    @method()
    def search(self, queries, origins, n_results=10):
        """Return the `n_results` nearest documents for each query, restricted by origin.

        Args:
            queries: query text(s) passed to Chroma as `query_texts`.
            origins: allowed values of the "origin" metadata field (Chroma `$in` filter).
                NOTE(review): Chroma may reject an empty list here -- confirm with callers.
            n_results: number of hits to return per query.

        Returns:
            A `(documents, metadatas, distances)` tuple taken straight from
            Chroma's query result (one inner list per query text).
        """
        results = self.chroma_collection.query(
            query_texts=queries,
            n_results=n_results,
            where={"origin": {"$in": origins}},
            include=['documents', 'metadatas', 'distances'])
        return results['documents'], results['metadatas'], results['distances']

@app.cls(timeout=30*60)
class RANKING:
    """Re-rank candidate documents against a query with a cross-encoder."""

    @enter()
    @build()
    def init(self):
        """Load the cross-encoder into `self.cross_encoder`.

        Also registered as a `@build` step, so the model download happens while
        the image is built and containers reuse the cached weights.
        """
        from sentence_transformers import CrossEncoder
        model_name = "Lajavaness/CrossEncoder-camembert-large"
        self.cross_encoder = CrossEncoder(model_name)
        print(f"Cross encoder model loaded: {model_name}")

    @method()
    def rank(self, query, documents):
        """Return indices into `documents`, ordered from highest to lowest score.

        Args:
            query: the query text.
            documents: candidate texts; each is scored as the pair [query, doc].

        Returns:
            A list of ints (a permutation of range(len(documents))); [] when
            there are no documents.
        """
        if not documents:
            # CrossEncoder.predict reads sentences[0] up front, so an empty batch
            # would raise IndexError instead of producing an empty ranking.
            return []
        pairs = [[query, doc] for doc in documents]
        scores = self.cross_encoder.predict(pairs)
        # argsort is ascending; reverse it to get best-first order.
        return np.argsort(scores)[::-1].tolist()

###########
# ENDPOINTS
###########
@app.function(timeout=30*60)
@web_endpoint(method="POST")
def retrieve(query: Dict):
    """POST endpoint: vector search via VECTORDB.search.

    Expects a JSON body with "query" and "origins", and optionally
    "n_results". When "n_results" is omitted, VECTORDB.search's own default
    is used instead of failing with a KeyError.

    Returns:
        {"documents": ..., "metadatas": ..., "distances": ...} as returned by search.
    """
    # Log query
    print(f"Retrieve query: {query}...")

    # Only forward n_results when the client supplied it, so `search` keeps
    # a single source of truth for the default.
    extra = {"n_results": query['n_results']} if 'n_results' in query else {}

    # Searching documents
    documents, metadatas, distances = VECTORDB().search.remote(query['query'], query['origins'], **extra)
    return {"documents" : documents, "metadatas" : metadatas, "distances" : distances}

@app.function(timeout=30*60)
@web_endpoint(method="POST")
def rank(query: Dict):
    """POST endpoint: re-rank documents for a query via RANKING.rank.

    Expects a JSON body with "query" (text) and "documents" (list of texts);
    responds with {"ranking": [...]} holding document indices, best first.
    """
    print(f"Rank query: {query}...")

    remote_rank = RANKING().rank.remote
    order = remote_rank(query['query'], query['documents'])
    return {"ranking": order}