Justin44 committed on
Commit
2698809
·
verified ·
1 Parent(s): 0fdc364

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -0
app.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from fastapi import FastAPI
3
+ from pydantic import BaseModel
4
+ from sentence_transformers import SentenceTransformer
5
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
6
+ from qdrant_client import models
7
+ import logging
8
+
9
# --- Setup Logging ---
# Module-level logger, named after this module per stdlib convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Configuration ---
# Local models for vector generation.
# Dense encoder: 384-dim MiniLM sentence embeddings.
DENSE_MODEL_ID = 'sentence-transformers/all-MiniLM-L6-v2'
# Use the corresponding QUERY encoder for SPLADE, which is optimized for search queries
# (the document-side encoder is a separate checkpoint).
SPLADE_QUERY_MODEL_ID = 'naver/efficient-splade-VI-BT-large-query'
# Prefer GPU when available; all model tensors are moved to this device.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Global Variables for Models ---
# These will be loaded once when the application starts (see load_models),
# then shared by every request handler. They remain None until startup runs.
dense_model = None
splade_tokenizer = None
splade_model = None

# --- FastAPI Application ---
app = FastAPI(
    title="Hybrid Vector Generation API",
    description="An API to generate dense and sparse vectors for a given text query.",
    version="1.0.0"
)
32
+
33
+ # --- Pydantic Models for API ---
34
+ class QueryRequest(BaseModel):
35
+ """Request model for the API, expecting a single text query."""
36
+ query_text: str
37
+
38
class SparseVectorResponse(BaseModel):
    """Response model for the sparse (SPLADE) vector.

    `indices` and `values` are parallel lists: values[i] is the weight of
    the vocabulary token at indices[i].
    """

    # Vocabulary-token positions with non-zero SPLADE weight.
    indices: list[int]
    # Corresponding non-negative weights, same length as `indices`.
    values: list[float]
42
+
43
class VectorResponse(BaseModel):
    """The final JSON response model containing both vectors."""

    # Dense sentence embedding as a plain list of floats.
    dense_vector: list[float]
    # Sparse SPLADE vector (parallel indices/values lists).
    sparse_vector: SparseVectorResponse
47
+
48
+
49
@app.on_event("startup")
async def load_models():
    """
    Load the ML models once on application startup.

    Populates the module-level globals (dense_model, splade_tokenizer,
    splade_model) so request handlers reuse the models instead of reloading
    them per request. Re-raises on failure so the app does not start in a
    half-initialized state.

    NOTE(review): @app.on_event is deprecated in current FastAPI in favor of
    lifespan handlers; kept here to preserve the existing startup interface.
    """
    global dense_model, splade_tokenizer, splade_model
    # Lazy %-args avoid formatting work when the log level is disabled.
    logger.info("Loading models onto device: %s", DEVICE)
    try:
        dense_model = SentenceTransformer(DENSE_MODEL_ID, device=DEVICE)
        splade_tokenizer = AutoTokenizer.from_pretrained(SPLADE_QUERY_MODEL_ID)
        splade_model = AutoModelForMaskedLM.from_pretrained(SPLADE_QUERY_MODEL_ID).to(DEVICE)
        # Inference-only service: eval() disables dropout for deterministic output.
        splade_model.eval()
        logger.info("Models initialized successfully.")
    except Exception:
        # logger.exception records the full traceback; 'fatal' is a deprecated
        # alias of 'critical' and is replaced here.
        logger.exception("FATAL: Could not initialize models.")
        # Bare raise preserves the original traceback so startup fails loudly.
        raise
66
+
67
def compute_splade_vector(text: str) -> models.SparseVector:
    """
    Compute a SPLADE sparse vector from a given text query.

    Args:
        text: The input text string.

    Returns:
        A Qdrant SparseVector whose `indices` are the non-zero vocabulary-token
        positions and whose `values` are the corresponding weights.
    """
    tokens = splade_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    # Move all input tensors to the model's device. NOTE: this rebinds
    # `tokens` to a plain dict, so attribute access no longer works on it.
    tokens = {key: val.to(DEVICE) for key, val in tokens.items()}

    with torch.no_grad():
        output = splade_model(**tokens)

    logits = output.logits
    # BUG FIX: `tokens` is a plain dict after the comprehension above, so the
    # original `tokens.attention_mask` raised AttributeError; index instead.
    attention_mask = tokens["attention_mask"]

    # SPLADE pooling: log(1 + ReLU(logits)), masked by attention,
    # then max over the sequence dimension.
    relu_log = torch.log(1 + torch.relu(logits))
    weighted_log = relu_log * attention_mask.unsqueeze(-1)
    max_val, _ = torch.max(weighted_log, dim=1)
    vec = max_val.squeeze()

    indices = vec.nonzero().squeeze().cpu().tolist()
    values = vec[indices].cpu().tolist()

    # nonzero().squeeze() collapses a single hit to a scalar, so tolist()
    # yields a bare number; normalize both fields back to lists.
    if not isinstance(indices, list):
        indices = [indices]
        values = [values]

    return models.SparseVector(indices=indices, values=values)
97
+
98
+
99
@app.post("/vectorize", response_model=VectorResponse)
async def vectorize_query(request: QueryRequest):
    """
    API endpoint to generate and return dense and sparse vectors for a text query.

    Args:
        request: A QueryRequest object containing the 'query_text'.

    Returns:
        A VectorResponse holding the dense embedding and the SPLADE sparse
        vector (serialized to JSON by FastAPI).
    """
    # Lazy %-style args defer string formatting until the record is emitted.
    logger.info("Received query for vectorization: '%s'", request.query_text)

    # 1. Generate the dense embedding via SentenceTransformer.
    logger.info("Generating dense vector...")
    dense_query_vector = dense_model.encode(request.query_text).tolist()
    logger.info("Dense vector generated.")

    # 2. Generate the sparse SPLADE vector.
    logger.info("Generating sparse vector...")
    sparse_query_vector = compute_splade_vector(request.query_text)
    logger.info("Sparse vector generated.")

    # 3. Construct and return the response.
    return VectorResponse(
        dense_vector=dense_query_vector,
        sparse_vector=SparseVectorResponse(
            indices=sparse_query_vector.indices,
            values=sparse_query_vector.values,
        ),
    )
130
+
131
@app.get("/", include_in_schema=False)
async def root():
    """Landing endpoint; excluded from the OpenAPI schema."""
    payload = {"message": "Vector Generation API is running. POST to /vectorize to get vectors."}
    return payload