# backend/vector_store.py
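"""Qdrant-backed vector store for CA-Foundation documents.

Wraps a Qdrant cluster with an all-MiniLM-L6-v2 SentenceTransformer and
exposes semantic vector search, full-text search over the indexed
content field, and lookup by document ID.
"""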
from qdrant_client import QdrantClient, models
from qdrant_client.models import PointStruct, PayloadSchemaType
from sentence_transformers import SentenceTransformer
import uuid
import os
import logging
from typing import List, Dict, Optional
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Configure logging
logger = logging.getLogger(__name__)

class VectorStore:
    def __init__(self, collection_name: str = "ca_documents"):
        self.collection_name = collection_name

        # Get Qdrant configuration from environment variables
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_api_key = os.getenv("QDRANT_API_KEY")

        if not qdrant_url or not qdrant_api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY environment variables are required")

        # Connect to Qdrant cluster with API key
        self.client = QdrantClient(
            url=qdrant_url,
            api_key=qdrant_api_key,
        )
        print("Connected to Qdrant")

        # Initialize embedding model with offline support
        self.embedding_model = self._initialize_embedding_model()

        # Create collection with proper indices
        self._create_collection_if_not_exists()
    def _initialize_embedding_model(self):
        """Initialize the embedding model, falling back to offline caches."""
        try:
            # Try to load the model normally first (may download from Hugging Face)
            print("Attempting to load sentence transformer model...")
            model = SentenceTransformer("all-MiniLM-L6-v2")
            print("Successfully loaded sentence transformer model")
            return model
        except Exception as e:
            print(f"Failed to load model online: {e}")
            print("Attempting to load model in offline mode...")
            try:
                # Force offline mode so cached files are used without network calls
                os.environ['TRANSFORMERS_OFFLINE'] = '1'
                os.environ['HF_HUB_OFFLINE'] = '1'
                model = SentenceTransformer("all-MiniLM-L6-v2")
                print("Successfully loaded model in offline mode")
                return model
            except Exception as offline_error:
                print(f"Failed to load model in offline mode: {offline_error}")

                # Try the default Hugging Face cache directory explicitly
                try:
                    cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "transformers")
                    if os.path.exists(cache_dir):
                        print(f"Looking for cached model in: {cache_dir}")
                        model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=cache_dir)
                        print("Successfully loaded model from cache")
                        return model
                except Exception as cache_error:
                    print(f"Failed to load from cache: {cache_error}")

                # If all else fails, provide instructions
                error_msg = """
Failed to initialize sentence transformer model. This is likely due to network connectivity issues.

Solutions:
1. Check your internet connection
2. If behind a corporate firewall, ensure huggingface.co is accessible
3. Pre-download the model when you have internet access by running:
   python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')"
4. Or manually download the model and place it in your cache directory

For now, the application will not work without the embedding model.
"""
                print(error_msg)
                raise RuntimeError(f"Cannot initialize embedding model: {str(e)}")
    def _create_collection_if_not_exists(self) -> bool:
        """
        Create collection with proper payload indices if it doesn't exist.

        Returns:
            bool: True if collection exists or was created successfully
        """
        try:
            # Check if collection exists
            collections = self.client.get_collections()
            collection_names = [col.name for col in collections.collections]

            if self.collection_name in collection_names:
                print(f"Collection '{self.collection_name}' already exists")
                return True

            print(f"Creating new collection: {self.collection_name}")

            # Vector size for all-MiniLM-L6-v2 is 384
            vector_size = 384

            # Create collection with vector configuration
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    size=vector_size,
                    distance=models.Distance.COSINE
                ),
                # m=0 disables the global HNSW graph; payload_m=16 builds
                # per-payload graphs instead (Qdrant's multitenancy pattern)
                hnsw_config=models.HnswConfigDiff(
                    payload_m=16,
                    m=0,
                ),
            )

            # Create payload indices
            payload_indices = {
                "document_id": PayloadSchemaType.KEYWORD,
                "content": PayloadSchemaType.TEXT
            }
            for field_name, schema_type in payload_indices.items():
                self.client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field_name,
                    field_schema=schema_type
                )

            print(f"Successfully created collection: {self.collection_name}")
            return True
        except Exception as e:
            error_msg = f"Failed to create collection {self.collection_name}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            print(error_msg)
            return False
    def add_document(self, text: str, metadata: Dict = None) -> bool:
        """Add a document to the collection"""
        try:
            # Generate embedding
            embedding = self.embedding_model.encode([text])[0]

            # Generate document ID
            document_id = str(uuid.uuid4())

            # Create payload with indexed fields
            payload = {
                "document_id": document_id,  # KEYWORD index
                "content": text,             # TEXT index - stores the actual text content
            }

            # Add metadata fields if provided
            if metadata:
                payload.update(metadata)

            # Create point
            point = PointStruct(
                id=document_id,
                vector=embedding.tolist(),
                payload=payload
            )

            # Store in Qdrant
            self.client.upsert(
                collection_name=self.collection_name,
                points=[point]
            )
            return True
        except Exception as e:
            print(f"Error adding document: {e}")
            return False
    def search_similar(self, query: str, limit: int = 5) -> List[Dict]:
        """Search for similar documents"""
        try:
            # Generate query embedding
            query_embedding = self.embedding_model.encode([query])[0]

            # Search in Qdrant
            results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding.tolist(),
                limit=limit
            )

            # Return results
            return [
                {
                    "text": hit.payload["content"],  # Use content field
                    "document_id": hit.payload.get("document_id"),
                    "score": hit.score,
                    # Include any additional metadata fields
                    **{k: v for k, v in hit.payload.items() if k not in ["content", "document_id"]}
                }
                for hit in results
            ]
        except Exception as e:
            print(f"Error searching: {e}")
            return []
    def search_by_document_id(self, document_id: str) -> Optional[Dict]:
        """Search for a specific document by its ID using the indexed field"""
        try:
            # Use scroll to find document by document_id
            results = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="document_id",
                            match=models.MatchValue(value=document_id)
                        )
                    ]
                ),
                limit=1
            )

            if results[0]:  # results is a tuple (points, next_page_offset)
                hit = results[0][0]  # Get first point
                return {
                    "text": hit.payload["content"],  # Use content field
                    "document_id": hit.payload.get("document_id"),
                    # Include any additional metadata fields
                    **{k: v for k, v in hit.payload.items() if k not in ["content", "document_id"]}
                }
            else:
                return None
        except Exception as e:
            print(f"Error searching by document ID: {e}")
            return None
    def search_by_content(self, content_query: str, limit: int = 5) -> List[Dict]:
        """Search for documents by content using the TEXT index"""
        try:
            # Use scroll with a full-text match filter on the content field
            results = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="content",
                            match=models.MatchText(text=content_query)
                        )
                    ]
                ),
                limit=limit
            )

            # Return results (results[0] contains the points)
            return [
                {
                    "text": hit.payload["content"],  # Use content field
                    "document_id": hit.payload.get("document_id"),
                    # Include any additional metadata fields
                    **{k: v for k, v in hit.payload.items() if k not in ["content", "document_id"]}
                }
                for hit in results[0]
            ]
        except Exception as e:
            print(f"Error searching by content: {e}")
            return []
    def get_collection_info(self) -> Dict:
        """Get information about the collection"""
        try:
            collection_info = self.client.get_collection(self.collection_name)
            return {
                "name": self.collection_name,
                "vector_size": collection_info.config.params.vectors.size,
                "distance": collection_info.config.params.vectors.distance,
                "points_count": collection_info.points_count,
                "on_disk": collection_info.config.params.vectors.on_disk
            }
        except Exception as e:
            print(f"Error getting collection info: {e}")
            return {}
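
# A minimal usage sketch, not part of the original module. It assumes
# QDRANT_URL and QDRANT_API_KEY are set in the environment (or a .env
# file) and that the Qdrant cluster is reachable; the sample text and
# query below are illustrative only.
if __name__ == "__main__":
    store = VectorStore(collection_name="ca_documents")

    # Index a sample document with optional metadata
    store.add_document(
        "Accounting is the process of recording financial transactions.",
        metadata={"subject": "accounting"},
    )

    # Semantic search over embeddings
    for hit in store.search_similar("What is accounting?", limit=3):
        print(hit["score"], hit["text"])

    # Full-text search over the indexed content field
    print(store.search_by_content("financial", limit=3))

    print(store.get_collection_info())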