# CA-Foundation / backend/vector_store.py
# Author: vinit5112
# Commit message: "post changes"
# Commit: 1edfa40
from qdrant_client import QdrantClient, models, grpc
from qdrant_client.http.models import PointStruct, PayloadSchemaType
from sentence_transformers import SentenceTransformer
import uuid
import os
import logging
from typing import List, Dict, Any
from dotenv import load_dotenv
import time
import asyncio
# Load environment variables (e.g. from a local .env file) before anything reads
# them; VectorStore.__init__ below looks up QDRANT_URL and QDRANT_API_KEY via os.getenv()
load_dotenv()
# Module-level logger named after this module; used for error reporting in VectorStore
logger = logging.getLogger(__name__)
class VectorStore:
    """Qdrant-backed store for text documents and their sentence embeddings.

    The constructor connects to Qdrant and loads a local sentence-transformer
    model.  Call ``await store.initialize()`` afterwards to make sure the
    collection exists.

    ``self.client`` is the *synchronous* ``QdrantClient``: its methods return
    plain results, not awaitables.  Every client call issued from a coroutine
    therefore goes through ``_run_blocking`` (a worker thread) instead of being
    awaited directly, which would raise ``TypeError`` after the call ran.
    """

    def __init__(self):
        """Connect to Qdrant and load the embedding model.

        Raises:
            ValueError: if QDRANT_URL or QDRANT_API_KEY is missing/empty.
            RuntimeError: if the local embedding model cannot be loaded.
        """
        self.collection_name = "ca-documents"

        # Get Qdrant configuration from environment variables
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_api_key = os.getenv("QDRANT_API_KEY")
        if not qdrant_url or not qdrant_api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY environment variables are required")

        # Connect to Qdrant cluster with API key
        self.client = QdrantClient(
            url=qdrant_url,
            api_key=qdrant_api_key,
            prefer_grpc=True,
        )
        print("Connected to Qdrant")

        # Initialize embedding model from the local model directory
        self.embedding_model = self._initialize_embedding_model()

    async def initialize(self):
        """Asynchronous initialization to be called after object creation."""
        await self._ensure_collection_exists()

    async def _run_blocking(self, func, *args, **kwargs):
        """Run a blocking callable in the default executor and return its result.

        Used for every synchronous Qdrant client call so coroutines neither
        block the event loop nor try to ``await`` a non-awaitable result.
        Exceptions raised by ``func`` propagate to the caller unchanged.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: func(*args, **kwargs))

    def _initialize_embedding_model(self):
        """Load the sentence-transformer model from ``../model/all-MiniLM-L6-v2``.

        The path is resolved relative to this file's directory.

        Raises:
            RuntimeError: if loading fails; the original error is chained.
        """
        try:
            print("Loading sentence transformer model from local path...")
            # Resolve local path to model directory
            current_dir = os.path.dirname(os.path.abspath(__file__))
            local_model_path = os.path.join(current_dir, "..", "model", "all-MiniLM-L6-v2")
            model = SentenceTransformer(local_model_path)
            print("Successfully loaded local sentence transformer model")
            return model
        except Exception as e:
            print(f"Failed to load local model: {e}")
            # Chain the cause so the real failure stays visible in tracebacks
            raise RuntimeError("Failed to initialize embedding model from local path") from e

    async def _collection_exists_and_accessible(self) -> bool:
        """Check that the collection exists by fetching its info.

        Returns:
            bool: True if ``get_collection`` succeeds, False on any error.
        """
        try:
            # Fetching the collection info is more reliable than listing collections
            await self._run_blocking(self.client.get_collection, self.collection_name)
            print(f"Collection '{self.collection_name}' exists and is accessible")
            return True
        except Exception as e:
            print(f"Collection '{self.collection_name}' is not accessible: {e}")
            return False

    async def _create_collection(self) -> bool:
        """Create the collection and its payload indices.

        A failure to create an individual payload index is reported as a
        warning and does not fail the whole operation.

        Returns:
            bool: True if the collection was created or already exists,
            False on any other error.
        """
        try:
            print(f"Creating new collection: {self.collection_name}")
            # Vector size for all-MiniLM-L6-v2 is 384
            vector_size = 384

            # NOTE(review): m=0 together with payload_m=16 looks like Qdrant's
            # multitenancy setup (no global HNSW graph, per-payload graphs only).
            # Unfiltered searches may then fall back to a full scan - confirm
            # this is intended.
            await self._run_blocking(
                self.client.create_collection,
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    size=vector_size,
                    distance=models.Distance.COSINE,
                ),
                hnsw_config=models.HnswConfigDiff(
                    payload_m=16,
                    m=0,
                ),
            )

            # Wait a moment for collection to be fully created
            await asyncio.sleep(1)

            # Create payload indices
            payload_indices = {
                "document_id": PayloadSchemaType.KEYWORD,
                "content": PayloadSchemaType.TEXT,
            }
            for field_name, schema_type in payload_indices.items():
                try:
                    await self._run_blocking(
                        self.client.create_payload_index,
                        collection_name=self.collection_name,
                        field_name=field_name,
                        field_schema=schema_type,
                    )
                except Exception as idx_error:
                    print(f"Warning: Failed to create index for {field_name}: {idx_error}")

            print(f"Successfully created collection: {self.collection_name}")
            return True
        except Exception as e:
            # Check if the error is because collection already exists
            message = str(e)
            if "already exists" in message.lower() or "ALREADY_EXISTS" in message:
                print(f"Collection '{self.collection_name}' already exists, using existing collection")
                return True
            error_msg = f"Failed to create collection {self.collection_name}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            print(error_msg)
            return False

    async def _ensure_collection_exists(self) -> bool:
        """Ensure the collection exists and is accessible, creating it if needed.

        Returns:
            bool: True if the collection exists or was created successfully.
        """
        try:
            # First, check if collection exists and is accessible
            if await self._collection_exists_and_accessible():
                print(f"Collection '{self.collection_name}' is ready to use")
                return True

            # If not accessible, try to create it (or verify it exists)
            print(f"Collection '{self.collection_name}' not immediately accessible, attempting to create/verify...")
            if not await self._create_collection():
                return False

            # After creation attempt, verify it's accessible
            if await self._collection_exists_and_accessible():
                print(f"Collection '{self.collection_name}' is now ready to use")
            else:
                # Created successfully but not immediately accessible, which is okay
                print(f"Collection '{self.collection_name}' created/verified successfully")
            return True
        except Exception as e:
            error_msg = f"Failed to ensure collection exists: {str(e)}"
            logger.error(error_msg, exc_info=True)
            print(error_msg)
            return False

    async def add_document(self, text: str, metadata: Dict = None) -> bool:
        """Embed ``text`` and upsert it as a single point, with retries.

        Args:
            text: Document text; stored under the ``content`` payload key.
            metadata: Optional extra payload fields, merged on top of the
                ``document_id``/``content`` entries.

        Returns:
            bool: True once an upsert call completes without raising, False
            after all attempts have failed.
        """
        max_retries = 3
        retry_delay = 1

        # One id for all attempts: a retry after an ambiguous failure then
        # overwrites the same point instead of storing a duplicate.
        document_id = str(uuid.uuid4())

        for attempt in range(max_retries):
            try:
                # Ensure collection exists before adding document
                if not await self._collection_exists_and_accessible():
                    print("Collection not accessible, trying to recreate...")
                    if not await self._create_collection():
                        raise Exception("Failed to create collection")

                # Generate embedding
                embedding = self.embedding_model.encode([text])[0]

                # Create payload with indexed fields
                payload = {
                    "document_id": document_id,  # KEYWORD index
                    "content": text,  # TEXT index - stores the actual text content
                }
                # Add metadata fields if provided
                if metadata:
                    payload.update(metadata)

                point = PointStruct(
                    id=document_id,
                    vector=embedding.tolist(),
                    payload=payload,
                )

                # Store in Qdrant
                result = await self._run_blocking(
                    self.client.upsert,
                    collection_name=self.collection_name,
                    points=[point],
                )

                # Any result returned without an exception counts as success;
                # only an unrecognised result shape is reported.
                completed = getattr(result, "status", None) == "completed"
                if not completed and not hasattr(result, "operation_id"):
                    print(f"Unexpected upsert result: {result}")
                return True
            except Exception as e:
                print(f"Error adding document (attempt {attempt + 1}/{max_retries}): {e}")
                if "Not found" in str(e) and "doesn't exist" in str(e):
                    # Collection doesn't exist, try to recreate
                    print("Collection not found, attempting to recreate...")
                    await self._create_collection()
                if attempt < max_retries - 1:
                    print(f"Retrying in {retry_delay} seconds...")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2  # Exponential backoff
                else:
                    print(f"Failed to add document after {max_retries} attempts")
                    return False
        return False

    async def search_similar(self, query: str, limit: int = 5) -> List[Dict]:
        """Return up to ``limit`` stored documents most similar to ``query``.

        Each result is a dict with ``text``, ``document_id`` and ``score``
        plus any remaining payload fields.  Returns an empty list when the
        collection is not accessible or any error occurs.
        """
        try:
            # Ensure collection exists before searching
            if not await self._collection_exists_and_accessible():
                print("Collection not accessible for search")
                return []

            # Generate query embedding
            query_embedding = self.embedding_model.encode([query])[0]

            # Search in Qdrant
            hits = await self._run_blocking(
                self.client.search,
                collection_name=self.collection_name,
                query_vector=query_embedding.tolist(),
                limit=limit,
            )

            results = []
            for hit in hits:
                # A point without payload/content must not discard the other hits
                payload = hit.payload or {}
                extra = {
                    key: value
                    for key, value in payload.items()
                    if key not in ("content", "document_id")
                }
                results.append(
                    {
                        "text": payload.get("content", ""),
                        "document_id": payload.get("document_id"),
                        "score": hit.score,
                        # Include any additional metadata fields
                        **extra,
                    }
                )
            return results
        except Exception as e:
            print(f"Error searching: {e}")
            return []

    async def get_collection_info(self) -> Dict:
        """Return basic information about the collection.

        Returns:
            Dict: ``name``, ``vector_size``, ``distance``, ``points_count`` and
            ``indexed_only``; an empty dict if the lookup fails.
        """
        try:
            collection_info = await self._run_blocking(
                self.client.get_collection, self.collection_name
            )
            vectors = collection_info.config.params.vectors
            return {
                # The collection config carries no name, so report the one we queried
                "name": self.collection_name,
                "vector_size": vectors.size,
                "distance": vectors.distance,
                "points_count": collection_info.points_count,
                # NOTE(review): key kept for compatibility; the value is the
                # vectors' on_disk setting, not an "indexed only" flag.
                "indexed_only": vectors.on_disk,
            }
        except Exception as e:
            print(f"Error getting collection info: {e}")
            return {}

    async def verify_collection_health(self) -> bool:
        """Verify that the collection is reachable and a search can be issued.

        Returns:
            bool: False if the collection info cannot be fetched or an
            unexpected error is raised, True otherwise.
        """
        try:
            # Try to get collection info
            info = await self.get_collection_info()
            if not info:
                return False

            # Smoke-test the search path; search_similar reports its own errors
            # and returns [] rather than raising, even when nothing matches.
            await self.search_similar("test query", limit=1)

            print(f"Collection health check passed. Points count: {info.get('points_count', 0)}")
            return True
        except Exception as e:
            print(f"Collection health check failed: {e}")
            return False