# Justin44's picture
# Create app.py
# 2698809 verified
# raw
# history blame
# 4.66 kB
import asyncio
import logging

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from qdrant_client import models
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForMaskedLM
# --- Setup Logging ---
# Named logger used by every progress / error message in this module.
logger = logging.getLogger(name=__name__)
# Root configuration at INFO so the logger.info(...) progress messages are emitted.
logging.basicConfig(level=logging.INFO)
# --- Configuration ---
# Hugging Face model ids for the two locally-run encoders.

# Dense sentence-embedding model.
DENSE_MODEL_ID = 'sentence-transformers/all-MiniLM-L6-v2'

# SPLADE *query* encoder: this service only vectorizes search queries,
# so the query-side checkpoint is used rather than the document-side one.
SPLADE_QUERY_MODEL_ID = 'naver/efficient-splade-VI-BT-large-query'

# Run on the GPU when PyTorch reports one is usable, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
# --- Global Variables for Models ---
# All three start as None and are assigned exactly once by the startup
# handler (load_models); request handlers read them afterwards.
dense_model = splade_tokenizer = splade_model = None
# --- FastAPI Application ---
# Single application instance; the route and startup decorators in this module attach to it.
app = FastAPI(
    version="1.0.0",
    title="Hybrid Vector Generation API",
    description="An API to generate dense and sparse vectors for a given text query.",
)
# --- Pydantic Models for API ---
class QueryRequest(BaseModel):
    """Request model for the API, expecting a single text query."""
    # The text to vectorize. It is passed unchanged to both the dense encoder
    # and the SPLADE encoder; the SPLADE tokenizer truncates it to 512 tokens.
    query_text: str
class SparseVectorResponse(BaseModel):
    """Response model for the sparse vector."""
    # Positions of the non-zero entries of the SPLADE vector.
    indices: list[int]
    # Weight at each of those positions; values[i] pairs with indices[i].
    values: list[float]
class VectorResponse(BaseModel):
    """The final JSON response model containing both vectors."""
    # Output of the SentenceTransformer dense model, as a plain list of floats.
    dense_vector: list[float]
    # SPLADE output in sparse (indices / values) form.
    sparse_vector: SparseVectorResponse
@app.on_event("startup")
async def load_models():
    """
    Load the dense and SPLADE models into the module-level globals on startup.

    Runs when FastAPI fires its "startup" event, so the models are loaded once
    per process. Any loading error is logged at CRITICAL level together with its
    traceback and then re-raised, so the failure propagates out of the startup
    event instead of leaving the model globals set to None.

    Raises:
        Exception: whatever the model/tokenizer loaders raised, unchanged.
    """
    # NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    # favour of a ``lifespan`` handler passed to ``FastAPI(...)``; migrating
    # requires changing the app construction as well.
    global dense_model, splade_tokenizer, splade_model
    logger.info("Loading models onto device: %s", DEVICE)
    try:
        dense_model = SentenceTransformer(DENSE_MODEL_ID, device=DEVICE)
        splade_tokenizer = AutoTokenizer.from_pretrained(SPLADE_QUERY_MODEL_ID)
        splade_model = AutoModelForMaskedLM.from_pretrained(SPLADE_QUERY_MODEL_ID).to(DEVICE)
    except Exception as e:
        # exc_info=True keeps the traceback in the log record (the previous
        # f-string message only captured str(e)). The app deliberately refuses
        # to start without its models, so the error is re-raised as-is.
        logger.critical("FATAL: Could not initialize models. Error: %s", e, exc_info=True)
        raise
    logger.info("Models initialized successfully.")
def compute_splade_vector(text: str) -> models.SparseVector:
    """
    Computes a SPLADE sparse vector from a given text query.

    For every output dimension of the masked-LM head, the weight is the maximum
    over token positions of log(1 + ReLU(logit)), with positions whose
    attention_mask is 0 zeroed out. Only the non-zero dimensions are returned.

    Args:
        text: The input text string. Tokenized inputs longer than 512 tokens
            are truncated.
    Returns:
        A Qdrant SparseVector object; ``values[i]`` is the weight at position
        ``indices[i]``. Both lists are always plain Python lists (possibly empty).
    """
    # Keep the tokenizer's BatchEncoding (``.to`` moves its tensors and returns
    # it) instead of rebuilding a plain dict: with a plain dict the previous
    # ``tokens.attention_mask`` attribute access raised AttributeError.
    tokens = splade_tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
    attention_mask = tokens["attention_mask"]

    with torch.no_grad():
        logits = splade_model(**tokens).logits
        # Log-saturated ReLU activations; the mask is broadcast over the last
        # (output) axis so masked positions contribute 0.
        weights = torch.log1p(torch.relu(logits)) * attention_mask.unsqueeze(-1)
        # Max-pool over the token axis, then take the single batch row
        # (a single string was tokenized, so the batch has one entry).
        vec = torch.max(weights, dim=1).values[0]

    # as_tuple=True always yields a 1-D index tensor, whether there are zero,
    # one or many non-zero entries, so no scalar special-casing is needed.
    (nonzero_idx,) = torch.nonzero(vec, as_tuple=True)
    return models.SparseVector(
        indices=nonzero_idx.cpu().tolist(),
        values=vec[nonzero_idx].cpu().tolist(),
    )
@app.post("/vectorize", response_model=VectorResponse)
async def vectorize_query(request: QueryRequest):
    """
    API endpoint to generate and return dense and sparse vectors for a text query.

    Both encoders run synchronous PyTorch inference. Calling them directly from
    this ``async def`` would block the event loop -- and with it every other
    request, including ``GET /`` -- for the whole forward pass, so each call is
    offloaded to a worker thread with ``asyncio.to_thread``.

    Args:
        request: A QueryRequest object containing the 'query_text'.
    Returns:
        A JSON response containing the dense and sparse vectors.
    """
    # NOTE(review): with inference in worker threads, concurrent requests can
    # now use the models/tokenizer at the same time; if that proves unsafe for
    # these models, serialize the two offloaded calls below behind a lock.
    query_text = request.query_text
    # Lazy %-formatting: identical message, built only if INFO is enabled.
    logger.info("Received query for vectorization: '%s'", query_text)

    # 1. Generate Dense Vector (.tolist() turns the encoder output into a JSON-friendly list)
    logger.info("Generating dense vector...")
    dense_embedding = await asyncio.to_thread(dense_model.encode, query_text)
    dense_query_vector = dense_embedding.tolist()
    logger.info("Dense vector generated.")

    # 2. Generate Sparse Vector
    logger.info("Generating sparse vector...")
    sparse_query_vector = await asyncio.to_thread(compute_splade_vector, query_text)
    logger.info("Sparse vector generated.")

    # 3. Construct and return the response
    return VectorResponse(
        dense_vector=dense_query_vector,
        sparse_vector=SparseVectorResponse(
            indices=sparse_query_vector.indices,
            values=sparse_query_vector.values,
        ),
    )
@app.get("/", include_in_schema=False)
async def root():
    """Health/landing endpoint (hidden from the schema) telling callers where to POST."""
    return dict(
        message="Vector Generation API is running. POST to /vectorize to get vectors.",
    )