import logging

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForMaskedLM
from qdrant_client import models

# --- Setup Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Configuration ---
# Local models for vector generation
DENSE_MODEL_ID = 'sentence-transformers/all-MiniLM-L6-v2'
# Use the corresponding QUERY encoder for SPLADE, which is optimized for search queries
SPLADE_QUERY_MODEL_ID = 'naver/efficient-splade-VI-BT-large-query'
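# Note: this service encodes *queries*; documents indexed on the search side should
# be encoded with the matching document encoder ('naver/efficient-splade-VI-BT-large-doc').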
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# --- Global Variables for Models ---
# These will be loaded once when the application starts
dense_model = None
splade_tokenizer = None
splade_model = None

# --- FastAPI Application ---
app = FastAPI(
    title="Hybrid Vector Generation API",
    description="An API to generate dense and sparse vectors for a given text query.",
    version="1.0.0",
)

# --- Pydantic Models for API ---
class QueryRequest(BaseModel):
    """Request model for the API, expecting a single text query."""
    query_text: str


class SparseVectorResponse(BaseModel):
    """Response model for the sparse vector."""
    indices: list[int]
    values: list[float]


class VectorResponse(BaseModel):
    """The final JSON response model containing both vectors."""
    dense_vector: list[float]
    sparse_vector: SparseVectorResponse

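# Example exchange for POST /vectorize (numbers are illustrative placeholders,
# not real model output):
#   request:  {"query_text": "what is hybrid search?"}
#   response: {"dense_vector": [0.021, -0.047, ...],
#              "sparse_vector": {"indices": [2054, 2003, ...], "values": [1.82, 0.94, ...]}}
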
@app.on_event("startup")
async def load_models():
"""
Asynchronous event to load ML models on application startup.
This ensures models are loaded only once.
"""
global dense_model, splade_tokenizer, splade_model
logger.info(f"Loading models onto device: {DEVICE}")
try:
dense_model = SentenceTransformer(DENSE_MODEL_ID, device=DEVICE)
splade_tokenizer = AutoTokenizer.from_pretrained(SPLADE_QUERY_MODEL_ID)
splade_model = AutoModelForMaskedLM.from_pretrained(SPLADE_QUERY_MODEL_ID).to(DEVICE)
logger.info("Models initialized successfully.")
except Exception as e:
logger.fatal(f"FATAL: Could not initialize models. Error: {e}")
# In a real-world scenario, you might want the app to fail startup if models don't load.
raise e
def compute_splade_vector(text: str) -> models.SparseVector:
    """
    Computes a SPLADE sparse vector from a given text query.

    Args:
        text: The input text string.

    Returns:
        A Qdrant SparseVector object.
    """
    tokens = splade_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    tokens = {key: val.to(DEVICE) for key, val in tokens.items()}  # Move tensors to the correct device
    with torch.no_grad():
        output = splade_model(**tokens)
    # `tokens` is a plain dict after the comprehension above, so it must be indexed
    # by key; attribute access (tokens.attention_mask) would raise AttributeError.
    logits, attention_mask = output.logits, tokens["attention_mask"]
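    # SPLADE term weighting (a gloss of the three lines below): ReLU keeps only
    # positive logits, log(1 + x) damps very large activations, the attention mask
    # zeroes out padded positions, and the max over the sequence dimension leaves
    # one weight per vocabulary term.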
    relu_log = torch.log(1 + torch.relu(logits))
    weighted_log = relu_log * attention_mask.unsqueeze(-1)
    max_val, _ = torch.max(weighted_log, dim=1)
    vec = max_val.squeeze()
    indices = vec.nonzero().squeeze().cpu().tolist()
    values = vec[indices].cpu().tolist()
    # Ensure indices and values are always lists, even for a single-element tensor
    if not isinstance(indices, list):
        indices = [indices]
        values = [values]
    return models.SparseVector(indices=indices, values=values)

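# The SparseVector above pairs with a dense vector for hybrid search in Qdrant.
# A minimal sketch, kept commented out because this service does not talk to Qdrant
# itself; it assumes qdrant-client >= 1.10, a hypothetical collection "docs" with
# named vectors "dense" and "sparse", and an existing QdrantClient instance `client`:
#
#   client.query_points(
#       collection_name="docs",
#       prefetch=[
#           models.Prefetch(query=compute_splade_vector(q), using="sparse", limit=20),
#           models.Prefetch(query=dense_model.encode(q).tolist(), using="dense", limit=20),
#       ],
#       query=models.FusionQuery(fusion=models.Fusion.RRF),  # reciprocal rank fusion
#       limit=10,
#   )
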
@app.post("/vectorize", response_model=VectorResponse)
async def vectorize_query(request: QueryRequest):
"""
API endpoint to generate and return dense and sparse vectors for a text query.
Args:
request: A QueryRequest object containing the 'query_text'.
Returns:
A JSON response containing the dense and sparse vectors.
"""
logger.info(f"Received query for vectorization: '{request.query_text}'")
# 1. Generate Dense Vector
logger.info("Generating dense vector...")
dense_query_vector = dense_model.encode(request.query_text).tolist()
logger.info("Dense vector generated.")
# 2. Generate Sparse Vector
logger.info("Generating sparse vector...")
sparse_query_vector = compute_splade_vector(request.query_text)
logger.info("Sparse vector generated.")
# 3. Construct and return the response
return VectorResponse(
dense_vector=dense_query_vector,
sparse_vector=SparseVectorResponse(
indices=sparse_query_vector.indices,
values=sparse_query_vector.values
)
)
@app.get("/", include_in_schema=False)
async def root():
return {"message": "Vector Generation API is running. POST to /vectorize to get vectors."} |