# Author: Justin
# Program: Vectorizer for Hybrid Search
# Instructions: Check README.md
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForMaskedLM
from qdrant_client import models
import logging

# --- Setup Logging ---
# Configure logging to be more descriptive
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# --- Configuration ---
# Local models for vector generation
DENSE_MODEL_ID = 'sentence-transformers/all-MiniLM-L6-v2'
# Use the corresponding QUERY encoder for SPLADE, which is optimized for search queries
SPLADE_QUERY_MODEL_ID = 'naver/efficient-splade-VI-BT-large-query'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
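# Note: this service encodes *queries* only. Documents indexed in Qdrant would
# be encoded with the matching document-side SPLADE model (e.g.
# 'naver/efficient-splade-VI-BT-large-doc').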

# --- Global Variables for Models ---
# These will be loaded once when the application starts
dense_model = None
splade_tokenizer = None
splade_model = None

# --- FastAPI Application ---
app = FastAPI(
    title="Hybrid Vector Generation API",
    description="An API to generate dense and sparse vectors for a given text query.",
    version="1.2.0"
)

# --- Pydantic Models for API ---
class QueryRequest(BaseModel):
    """Request model for the API, expecting a single text query."""
    query_text: str

class SparseVectorResponse(BaseModel):
    """Response model for the sparse vector."""
    indices: list[int]
    values: list[float]

class VectorResponse(BaseModel):
    """Final JSON response model containing both vectors."""
    dense_vector: list[float]
    sparse_vector: SparseVectorResponse
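
# Example response body (values illustrative; all-MiniLM-L6-v2 produces
# 384-dimensional dense embeddings):
# {
#   "dense_vector": [0.012, -0.034, ...],
#   "sparse_vector": {"indices": [2003, 4189, ...], "values": [1.21, 0.87, ...]}
# }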


@app.on_event("startup")
async def load_models():
    """
    Asynchronous event to load ML models on application startup.
    This ensures models are loaded only once.
    """
    global dense_model, splade_tokenizer, splade_model
    logger.info("Server is starting up... Time to load the ML models.")
    logger.info(f"I'll be using the '{DEVICE}' for processing.")
    try:
        dense_model = SentenceTransformer(DENSE_MODEL_ID, device=DEVICE)
        splade_tokenizer = AutoTokenizer.from_pretrained(SPLADE_QUERY_MODEL_ID)
        splade_model = AutoModelForMaskedLM.from_pretrained(SPLADE_QUERY_MODEL_ID).to(DEVICE)
        logger.info("YAaay! All models have been loaded successfully.")
    except Exception as e:
        logger.critical(f"Oh no, a critical error occurred while loading models: {e}", exc_info=True)
        # In a real-world scenario, you might want the app to fail startup if models don't load
        raise e

def compute_splade_vector(text: str) -> models.SparseVector:
    """
    Computes a SPLADE sparse vector from a given text query.
    
    Args:
        text: The input text string.
    Returns:
        A Qdrant SparseVector object.
    """
    tokens = splade_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    tokens = {key: val.to(DEVICE) for key, val in tokens.items()} # Move tensors to the correct device

    with torch.no_grad():
        output = splade_model(**tokens)
    
    # SPLADE pooling: w_j = max_i log(1 + ReLU(logit_ij)). The log-saturated
    # ReLU keeps term weights sparse and positive; the attention mask zeroes
    # out padding positions before max-pooling over the sequence dimension.
    logits, attention_mask = output.logits, tokens['attention_mask']
    relu_log = torch.log(1 + torch.relu(logits))
    weighted_log = relu_log * attention_mask.unsqueeze(-1)
    max_val, _ = torch.max(weighted_log, dim=1)
    vec = max_val.squeeze()  # shape: [vocab_size]

    indices = vec.nonzero().squeeze().cpu().tolist()
    values = vec[indices].cpu().tolist()
    
    # Ensure indices and values are always lists, even for a single-element tensor
    if not isinstance(indices, list):
        indices = [indices]
        values = [values]
        
    return models.SparseVector(indices=indices, values=values)
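
# Illustrative usage (indices/values below are made up; real output depends on
# the model weights and tokenizer vocabulary):
#   vec = compute_splade_vector("hybrid search")
#   vec.indices  # e.g. [2003, 4189, 17258]
#   vec.values   # e.g. [1.42, 0.93, 0.51]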


@app.post("/vectorize", response_model=VectorResponse)
async def vectorize_query(request: QueryRequest):
    """
    API endpoint to generate and return dense and sparse vectors for a text query.
    
    Args:
        request: A QueryRequest object containing the 'query_text'.
        
    Returns:
        A JSON response containing the dense and sparse vectors.
    """
    # --- Request Logging (n8n workflow) ---
    logger.info("=========================================================")
    logger.info("Received a new vectorization request.")
    logger.info(f"Incoming query from n8n: '{request.query_text}'")
    
    # 1. Generate Dense Vector
    logger.info("Generating the dense vector (semantic similarity)...")
    dense_query_vector = dense_model.encode(request.query_text).tolist()
    logger.info("Dense vector generated: %d dimensions.", len(dense_query_vector))
    logger.info("Dense vector sample: %s...", str(dense_query_vector[:4]))

    # 2. Generate Sparse Vector
    logger.info("Generating the sparse vector (keyword matching)...")
    sparse_query_vector = compute_splade_vector(request.query_text)
    logger.info("Sparse vector generated: %d nonzero terms.", len(sparse_query_vector.indices))
    logger.info("Sparse index sample: %s...", str(sparse_query_vector.indices[:4]))

    # 3. Construct and return the response
    logger.info("Both vectors generated; returning response.")
    logger.info("-----------------------------------------------------------------")
    
    final_response = VectorResponse(
        dense_vector=dense_query_vector,
        sparse_vector=SparseVectorResponse(
            indices=sparse_query_vector.indices,
            values=sparse_query_vector.values
        )
    )
    return final_response
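
# Example call (assumes uvicorn's default port 8000; adjust to your deployment):
#   curl -X POST http://localhost:8000/vectorize \
#        -H "Content-Type: application/json" \
#        -d '{"query_text": "hybrid search"}'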

@app.get("/", include_in_schema=False)
async def root():
    return {"message": "Vector Generation API is running. -- VERSION 2 --"}