# Author : Justin
# Program : Vectorizer for Hybrid Search
# Instructions : Check README.md
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForMaskedLM
from qdrant_client import models
import logging
# --- Setup Logging ---
# Configure logging to be more descriptive
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
# --- Configuration ---
# Local models for vector generation
DENSE_MODEL_ID = 'sentence-transformers/all-MiniLM-L6-v2'
# Use the corresponding QUERY encoder for SPLADE, which is optimized for search queries
SPLADE_QUERY_MODEL_ID = 'naver/efficient-splade-VI-BT-large-query'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# --- Global Variables for Models ---
# These will be loaded once when the application starts
dense_model = None
splade_tokenizer = None
splade_model = None
# --- FastAPI Application ---
app = FastAPI(
title="Hybrid Vector Generation API",
description="An API to generate dense and sparse vectors for a given text query.",
version="1.2.0"
)
# --- Pydantic Models for API ---
class QueryRequest(BaseModel):
"""Request model for the API, expecting a single text query."""
query_text: str
class SparseVectorResponse(BaseModel):
"""Response model for the sparse vector."""
indices: list[int]
values: list[float]
class VectorResponse(BaseModel):
"""Final JSON response model containing both vectors."""
dense_vector: list[float]
sparse_vector: SparseVectorResponse
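# Illustrative request/response payloads for the models above (values are made up):
#   Request:  {"query_text": "how does hybrid search work?"}
#   Response: {"dense_vector": [0.0123, ...],
#              "sparse_vector": {"indices": [2023, ...], "values": [1.42, ...]}}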
@app.on_event("startup")
async def load_models():
"""
Asynchronous event to load ML models on application startup.
This ensures models are loaded only once.
"""
global dense_model, splade_tokenizer, splade_model
    logger.info("Server is starting up; loading ML models.")
    logger.info(f"Using device: '{DEVICE}'")
try:
dense_model = SentenceTransformer(DENSE_MODEL_ID, device=DEVICE)
splade_tokenizer = AutoTokenizer.from_pretrained(SPLADE_QUERY_MODEL_ID)
splade_model = AutoModelForMaskedLM.from_pretrained(SPLADE_QUERY_MODEL_ID).to(DEVICE)
        logger.info("All models loaded successfully.")
except Exception as e:
        logger.critical(f"A critical error occurred while loading models: {e}", exc_info=True)
        # Re-raise so the application fails startup when the models cannot be loaded
        raise
def compute_splade_vector(text: str) -> models.SparseVector:
"""
Computes a SPLADE sparse vector from a given text query.
Args:
text: The input text string.
Returns:
A Qdrant SparseVector object.
"""
tokens = splade_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
tokens = {key: val.to(DEVICE) for key, val in tokens.items()} # Move tensors to the correct device
with torch.no_grad():
output = splade_model(**tokens)
logits, attention_mask = output.logits, tokens['attention_mask']
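    # SPLADE scoring: log(1 + ReLU(logits)) saturates each term's activation, the
    # attention mask zeroes out padding positions, and max-pooling over the sequence
    # dimension yields a single weight per vocabulary term.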
relu_log = torch.log(1 + torch.relu(logits))
weighted_log = relu_log * attention_mask.unsqueeze(-1)
max_val, _ = torch.max(weighted_log, dim=1)
vec = max_val.squeeze()
indices = vec.nonzero().squeeze().cpu().tolist()
values = vec[indices].cpu().tolist()
# Ensure indices and values are always lists, even for a single-element tensor
if not isinstance(indices, list):
indices = [indices]
values = [values]
return models.SparseVector(indices=indices, values=values)
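# A quick local sanity check (illustrative; requires the models to be loaded first):
#   vec = compute_splade_vector("hybrid search with qdrant")
#   assert len(vec.indices) == len(vec.values)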
@app.post("/vectorize", response_model=VectorResponse)
async def vectorize_query(request: QueryRequest):
"""
API endpoint to generate and return dense and sparse vectors for a text query.
Args:
request: A QueryRequest object containing the 'query_text'.
Returns:
A JSON response containing the dense and sparse vectors.
"""
# --- n8n Logging ---
    logger.info("=========================================================")
    logger.info("Received a new vectorization request.")
    logger.info(f"Incoming search query from n8n: '{request.query_text}'")
# 1. Generate Dense Vector
    logger.info("Generating the dense vector for semantic matching...")
    dense_query_vector = dense_model.encode(request.query_text).tolist()
    logger.info("Dense vector generated with %d dimensions.", len(dense_query_vector))
    logger.info("Dense vector sample: %s...", str(dense_query_vector[:4]))
# 2. Generate Sparse Vector
    logger.info("Generating the sparse vector for keyword matching...")
    sparse_query_vector = compute_splade_vector(request.query_text)
    logger.info("Sparse vector generated with %d non-zero terms.", len(sparse_query_vector.indices))
    logger.info("Sparse vector index sample: %s...", str(sparse_query_vector.indices[:4]))
# 3. Construct and return the response
    logger.info("Packaging both vectors into the response.")
logger.info("-----------------------------------------------------------------")
final_response = VectorResponse(
dense_vector=dense_query_vector,
sparse_vector=SparseVectorResponse(
indices=sparse_query_vector.indices,
values=sparse_query_vector.values
)
)
return final_response
@app.get("/", include_in_schema=False)
async def root():
    return {"message": "Hybrid Vector Generation API is running."}
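# --- Local entry point ---
# A minimal sketch for running the service directly; it assumes uvicorn is
# installed and port 8000 is free (in deployment you would typically run
# `uvicorn app:app --host 0.0.0.0 --port 8000` instead).
# Example request once the server is up (hypothetical host/port):
#   curl -X POST http://localhost:8000/vectorize \
#        -H 'Content-Type: application/json' \
#        -d '{"query_text": "hybrid search with qdrant"}'
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)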