|
|
|
|
|
|
|
import torch |
|
from fastapi import FastAPI |
|
from pydantic import BaseModel |
|
from sentence_transformers import SentenceTransformer |
|
from transformers import AutoTokenizer, AutoModelForMaskedLM |
|
from qdrant_client import models |
|
import logging |
|
import json |
|
|
|
|
|
|
|
# Configure root logging once at import time so every handler in this process
# emits INFO-level records with a consistent timestamped format.
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)

# Module-scoped logger named after this module, per stdlib convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
# Sentence-transformers checkpoint used for the dense (semantic) embedding.
DENSE_MODEL_ID = 'sentence-transformers/all-MiniLM-L6-v2'

# SPLADE query-side checkpoint used to produce the sparse (lexical) vector.
SPLADE_QUERY_MODEL_ID = 'naver/efficient-splade-VI-BT-large-query'
# Prefer the GPU when present; models and input tensors are moved to DEVICE.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
|
|
|
# Model handles populated by load_models() during the FastAPI startup event;
# they remain None until that event has completed.
dense_model = None       # SentenceTransformer instance for dense embeddings
splade_tokenizer = None  # Hugging Face tokenizer paired with the SPLADE model
splade_model = None      # SPLADE masked-LM model (AutoModelForMaskedLM)
|
|
|
|
|
# FastAPI application instance; this metadata feeds the auto-generated
# OpenAPI schema and docs pages (/docs, /redoc).
app = FastAPI(
    title="Hybrid Vector Generation API",
    description="An API to generate dense and sparse vectors for a given text query.",
    version="1.2.0"
)
|
|
|
|
|
class QueryRequest(BaseModel):
    """Request model for the API, expecting a single text query."""
    # Raw text to vectorize; pydantic validates it as a required string.
    query_text: str
|
|
|
class SparseVectorResponse(BaseModel):
    """Response model for the sparse vector."""
    # Vocabulary indices of the non-zero SPLADE terms.
    indices: list[int]
    # Term weights, aligned one-to-one with `indices`.
    values: list[float]
|
|
|
class VectorResponse(BaseModel):
    """Final JSON response model containing both vectors."""
    # Dense semantic embedding from the sentence-transformers model.
    dense_vector: list[float]
    # Sparse lexical vector (SPLADE) as indices + values.
    sparse_vector: SparseVectorResponse
|
|
|
|
|
@app.on_event("startup")
async def load_models():
    """
    Load the ML models once on application startup.

    Populates the module-level ``dense_model``, ``splade_tokenizer`` and
    ``splade_model`` globals so that request handlers can reuse them.

    Raises:
        Exception: Re-raises any model-loading failure so the server fails
            fast instead of serving requests with missing models.
    """
    global dense_model, splade_tokenizer, splade_model
    logger.info("Server is starting up... Time to load the ML models.")
    # Lazy %-args defer formatting until the record is actually emitted.
    logger.info("I'll be using the '%s' for processing.", DEVICE)
    try:
        dense_model = SentenceTransformer(DENSE_MODEL_ID, device=DEVICE)
        splade_tokenizer = AutoTokenizer.from_pretrained(SPLADE_QUERY_MODEL_ID)
        splade_model = AutoModelForMaskedLM.from_pretrained(SPLADE_QUERY_MODEL_ID).to(DEVICE)
        logger.info("YAaay! All models have been loaded successfully.")
    except Exception as e:
        logger.critical("Oh no, a critical error occurred while loading models: %s", e, exc_info=True)
        # Bare `raise` re-raises with the original traceback intact;
        # `raise e` would rewrite the traceback's origin.
        raise
|
|
|
def compute_splade_vector(text: str) -> models.SparseVector:
    """
    Compute a SPLADE sparse vector from a given text query.

    Runs the text through the SPLADE masked-LM head and applies the standard
    SPLADE aggregation: log(1 + ReLU(logits)), weighted by the attention
    mask, then max-pooled over the token dimension.

    Args:
        text: The input text string.

    Returns:
        A Qdrant SparseVector holding the non-zero term indices and weights.
    """
    tokens = splade_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    tokens = {key: val.to(DEVICE) for key, val in tokens.items()}

    with torch.no_grad():
        output = splade_model(**tokens)

    logits, attention_mask = output.logits, tokens['attention_mask']
    # log1p(relu(x)) == log(1 + relu(x)) but numerically safer for small x.
    relu_log = torch.log1p(torch.relu(logits))
    weighted_log = relu_log * attention_mask.unsqueeze(-1)
    max_val, _ = torch.max(weighted_log, dim=1)
    # squeeze(0) removes only the batch dimension; a bare squeeze() would
    # also collapse other singleton dims in degenerate cases.
    vec = max_val.squeeze(0)

    # nonzero() yields shape (k, 1); squeeze(-1) keeps the index tensor 1-D
    # even when k <= 1, so .tolist() always returns lists and the previous
    # scalar special-case (isinstance check) is unnecessary.
    nz = vec.nonzero().squeeze(-1)
    indices = nz.cpu().tolist()
    values = vec[nz].cpu().tolist()

    return models.SparseVector(indices=indices, values=values)
|
|
|
|
|
@app.post("/vectorize", response_model=VectorResponse)
async def vectorize_query(request: QueryRequest):
    """
    API endpoint to generate and return dense and sparse vectors for a text query.

    Args:
        request: A QueryRequest object containing the 'query_text'.

    Returns:
        A VectorResponse containing the dense and sparse vectors.
    """
    logger.info("=========================================================")
    logger.info("A new request just arrived! Let's see what we've got.")
    # Lazy %-args keep the logging style consistent across this module and
    # avoid eager f-string formatting.
    logger.info("The incoming search query from n8n is: '%s'", request.query_text)

    # NOTE(review): encode() is a blocking, CPU/GPU-bound call inside an
    # async handler, so it blocks the event loop while it runs — consider
    # run_in_executor if throughput matters. Left as-is to preserve behavior.
    logger.info("First, generating the dense vector for semantic meaning...")
    dense_query_vector = dense_model.encode(request.query_text).tolist()
    logger.info("Done with the dense vector. It has %d dimensions.", len(dense_query_vector))
    logger.info("Here's a small sample of the dense vector: %s...", str(dense_query_vector[:4]))

    logger.info("Next up, creating the sparse vector for keyword matching...")
    sparse_query_vector = compute_splade_vector(request.query_text)
    logger.info("Sparse vector is ready. It contains %d important terms.", len(sparse_query_vector.indices))
    logger.info("Here's a sample of the sparse vector indices: %s...", str(sparse_query_vector.indices[:4]))

    logger.info("Everything looks good. Packaging up the vectors to send back.")
    logger.info("-----------------------------------------------------------------")

    return VectorResponse(
        dense_vector=dense_query_vector,
        sparse_vector=SparseVectorResponse(
            indices=sparse_query_vector.indices,
            values=sparse_query_vector.values,
        ),
    )
|
|
|
@app.get("/", include_in_schema=False)
async def root():
    """Liveness endpoint: confirms the API process is up and reachable."""
    status_payload = {"message": "Vector Generation API is running. -- VERSION 2 --"}
    return status_payload
|
|