import logging

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from qdrant_client import models
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForMaskedLM
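
# Assumed dependencies (versions unpinned; this install line is an assumption,
# not part of the original file):
#   pip install fastapi uvicorn torch transformers sentence-transformers qdrant-client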

# --- Setup Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Configuration ---
# Local models for vector generation
DENSE_MODEL_ID = 'sentence-transformers/all-MiniLM-L6-v2'
# Use the corresponding QUERY encoder for SPLADE, which is optimized for search queries
SPLADE_QUERY_MODEL_ID = 'naver/efficient-splade-VI-BT-large-query'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Global Variables for Models ---
# These will be loaded once when the application starts
dense_model = None
splade_tokenizer = None
splade_model = None

# --- FastAPI Application ---
app = FastAPI(
    title="Hybrid Vector Generation API",
    description="An API to generate dense and sparse vectors for a given text query.",
    version="1.0.0"
)

# --- Pydantic Models for API ---
class QueryRequest(BaseModel):
    """Request model for the API, expecting a single text query."""
    query_text: str

class SparseVectorResponse(BaseModel):
    """Response model for the sparse vector."""
    indices: list[int]
    values: list[float]

class VectorResponse(BaseModel):
    """The final JSON response model containing both vectors."""
    dense_vector: list[float]
    sparse_vector: SparseVectorResponse
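
# Illustrative response shape (values are made up, not real model output;
# all-MiniLM-L6-v2 produces 384-dimensional dense vectors):
# {
#   "dense_vector": [0.012, -0.034, ...],
#   "sparse_vector": {"indices": [2013, 8902], "values": [1.42, 0.87]}
# }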


@app.on_event("startup")
async def load_models():
    """
    Asynchronous event to load ML models on application startup.
    This ensures models are loaded only once.
    """
    global dense_model, splade_tokenizer, splade_model
    logger.info(f"Loading models onto device: {DEVICE}")
    try:
        dense_model = SentenceTransformer(DENSE_MODEL_ID, device=DEVICE)
        splade_tokenizer = AutoTokenizer.from_pretrained(SPLADE_QUERY_MODEL_ID)
        splade_model = AutoModelForMaskedLM.from_pretrained(SPLADE_QUERY_MODEL_ID).to(DEVICE)
        logger.info("Models initialized successfully.")
    except Exception as e:
        logger.critical(f"Could not initialize models: {e}")
        # Re-raise so the application fails fast at startup when models cannot load.
        raise

def compute_splade_vector(text: str) -> models.SparseVector:
    """
    Computes a SPLADE sparse vector from a given text query.
    
    Args:
        text: The input text string.
    Returns:
        A Qdrant SparseVector object.
    """
    tokens = splade_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    tokens = {key: val.to(DEVICE) for key, val in tokens.items()} # Move tensors to the correct device

    with torch.no_grad():
        output = splade_model(**tokens)
    
    logits, attention_mask = output.logits, tokens["attention_mask"]  # key access: `tokens` is a plain dict after the comprehension above
    relu_log = torch.log(1 + torch.relu(logits))
    weighted_log = relu_log * attention_mask.unsqueeze(-1)
    max_val, _ = torch.max(weighted_log, dim=1)
    vec = max_val.squeeze()

    indices = vec.nonzero().squeeze().cpu().tolist()
    values = vec[indices].cpu().tolist()
    
    # Ensure indices and values are always lists, even for a single-element tensor
    if not isinstance(indices, list):
        indices = [indices]
        values = [values]
        
    return models.SparseVector(indices=indices, values=values)
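
# Illustrative usage (valid only after load_models() has populated the globals;
# the indices and values shown are hypothetical):
#   vec = compute_splade_vector("what is hybrid search?")
#   # -> SparseVector(indices=[2054, 2003, ...], values=[0.91, 0.45, ...])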


@app.post("/vectorize", response_model=VectorResponse)
def vectorize_query(request: QueryRequest):  # plain `def`: FastAPI runs this blocking inference in a worker thread instead of on the event loop
    """
    API endpoint to generate and return dense and sparse vectors for a text query.
    
    Args:
        request: A QueryRequest object containing the 'query_text'.
        
    Returns:
        A JSON response containing the dense and sparse vectors.
    """
    logger.info(f"Received query for vectorization: '{request.query_text}'")

    # 1. Generate Dense Vector
    logger.info("Generating dense vector...")
    dense_query_vector = dense_model.encode(request.query_text).tolist()
    logger.info("Dense vector generated.")

    # 2. Generate Sparse Vector
    logger.info("Generating sparse vector...")
    sparse_query_vector = compute_splade_vector(request.query_text)
    logger.info("Sparse vector generated.")

    # 3. Construct and return the response
    return VectorResponse(
        dense_vector=dense_query_vector,
        sparse_vector=SparseVectorResponse(
            indices=sparse_query_vector.indices,
            values=sparse_query_vector.values
        )
    )
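
# Example request against a running instance (host and port are assumptions):
#   curl -X POST http://localhost:8000/vectorize \
#        -H "Content-Type: application/json" \
#        -d '{"query_text": "what is hybrid search?"}'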

@app.get("/", include_in_schema=False)
async def root():
    return {"message": "Vector Generation API is running. POST to /vectorize to get vectors."}