# Author : Justin # Program : Vectorizer for Hybrid Search # Instructions : Check README.md import torch from fastapi import FastAPI from pydantic import BaseModel from sentence_transformers import SentenceTransformer from transformers import AutoTokenizer, AutoModelForMaskedLM from qdrant_client import models import logging import json # --- Setup Logging --- # Configure logging to be more descriptive logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', ) logger = logging.getLogger(__name__) # --- Configuration --- # Local models for vector generation DENSE_MODEL_ID = 'sentence-transformers/all-MiniLM-L6-v2' # Use the corresponding QUERY encoder for SPLADE, which is optimized for search queries SPLADE_QUERY_MODEL_ID = 'naver/efficient-splade-VI-BT-large-query' DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' # --- Global Variables for Models --- # These will be loaded once when the application starts dense_model = None splade_tokenizer = None splade_model = None # --- FastAPI Application --- app = FastAPI( title="Hybrid Vector Generation API", description="An API to generate dense and sparse vectors for a given text query.", version="1.2.0" ) # --- Pydantic Models for API --- class QueryRequest(BaseModel): """Request model for the API, expecting a single text query.""" query_text: str class SparseVectorResponse(BaseModel): """Response model for the sparse vector.""" indices: list[int] values: list[float] class VectorResponse(BaseModel): """Final JSON response model containing both vectors.""" dense_vector: list[float] sparse_vector: SparseVectorResponse @app.on_event("startup") async def load_models(): """ Asynchronous event to load ML models on application startup. This ensures models are loaded only once. """ global dense_model, splade_tokenizer, splade_model logger.info("Server is starting up... Time to load the ML models.") logger.info(f"I'll be using the '{DEVICE}' for processing.") try: dense_model = SentenceTransformer(DENSE_MODEL_ID, device=DEVICE) splade_tokenizer = AutoTokenizer.from_pretrained(SPLADE_QUERY_MODEL_ID) splade_model = AutoModelForMaskedLM.from_pretrained(SPLADE_QUERY_MODEL_ID).to(DEVICE) logger.info("YAaay! All models have been loaded successfully.") except Exception as e: logger.critical(f"Oh no, a critical error occurred while loading models: {e}", exc_info=True) # In a real-world scenario, you might want the app to fail startup if models don't load raise e def compute_splade_vector(text: str) -> models.SparseVector: """ Computes a SPLADE sparse vector from a given text query. Args: text: The input text string. Returns: A Qdrant SparseVector object. """ tokens = splade_tokenizer(text, return_tensors="pt", truncation=True, max_length=512) tokens = {key: val.to(DEVICE) for key, val in tokens.items()} # Move tensors to the correct device with torch.no_grad(): output = splade_model(**tokens) logits, attention_mask = output.logits, tokens['attention_mask'] relu_log = torch.log(1 + torch.relu(logits)) weighted_log = relu_log * attention_mask.unsqueeze(-1) max_val, _ = torch.max(weighted_log, dim=1) vec = max_val.squeeze() indices = vec.nonzero().squeeze().cpu().tolist() values = vec[indices].cpu().tolist() # Ensure indices and values are always lists, even for a single-element tensor if not isinstance(indices, list): indices = [indices] values = [values] return models.SparseVector(indices=indices, values=values) @app.post("/vectorize", response_model=VectorResponse) async def vectorize_query(request: QueryRequest): """ API endpoint to generate and return dense and sparse vectors for a text query. Args: request: A QueryRequest object containing the 'query_text'. Returns: A JSON response containing the dense and sparse vectors. """ # --- n8n Logging --- logger.info("=========================================================") logger.info("A new request just arrived! Let's see what we've got.") logger.info(f"The incoming search query from n8n is: '{request.query_text}'") # 1. Generate Dense Vector logger.info("First, generating the dense vector for semantic meaning...") dense_query_vector = dense_model.encode(request.query_text).tolist() logger.info("Done with the dense vector. It has %d dimensions.", len(dense_query_vector)) logger.info("Here's a small sample of the dense vector: %s...", str(dense_query_vector[:4])) # 2. Generate Sparse Vector logger.info("Next up, creating the sparse vector for keyword matching...") sparse_query_vector = compute_splade_vector(request.query_text) logger.info("Sparse vector is ready. It contains %d important terms.", len(sparse_query_vector.indices)) logger.info("Here's a sample of the sparse vector indices: %s...", str(sparse_query_vector.indices[:4])) # 3. Construct and return the response logger.info("Everything looks good. Packaging up the vectors to send back.") logger.info("-----------------------------------------------------------------") final_response = VectorResponse( dense_vector=dense_query_vector, sparse_vector=SparseVectorResponse( indices=sparse_query_vector.indices, values=sparse_query_vector.values ) ) return final_response @app.get("/", include_in_schema=False) async def root(): return {"message": "Vector Generation API is running. -- VERSION 2 --"}