"""Qdrant-backed vector store for document text.

Embeds documents with the SentenceTransformer model "all-MiniLM-L6-v2"
(384-dimensional output) and stores them in a remote Qdrant cluster
configured via the QDRANT_URL / QDRANT_API_KEY environment variables.
"""

from qdrant_client import QdrantClient, models
from qdrant_client.models import PointStruct, PayloadSchemaType
from sentence_transformers import SentenceTransformer
import uuid
import os
import logging
from typing import List, Dict, Any, Optional
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Configure logging
logger = logging.getLogger(__name__)


class VectorStore:
    """Wrapper around a Qdrant collection of embedded documents.

    Raises:
        ValueError: if QDRANT_URL or QDRANT_API_KEY is missing.
        RuntimeError: if the embedding model cannot be initialized.
    """

    def __init__(self, collection_name: str = "ca_documents"):
        self.collection_name = collection_name

        # Get Qdrant configuration from environment variables
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_api_key = os.getenv("QDRANT_API_KEY")

        if not qdrant_url or not qdrant_api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY environment variables are required")

        # Connect to Qdrant cluster with API key
        self.client = QdrantClient(
            url=qdrant_url,
            api_key=qdrant_api_key,
        )
        print("Connected to Qdrant")

        # Initialize embedding model with offline support
        self.embedding_model = self._initialize_embedding_model()

        # Create collection with proper indices
        self._create_collection_if_not_exists()

    def _initialize_embedding_model(self) -> SentenceTransformer:
        """Initialize the embedding model, falling back to offline/cached loads.

        Tries, in order: a normal (possibly network) load, an offline-mode
        load, and a load from the local HuggingFace cache directory.

        Returns:
            SentenceTransformer: the loaded model.

        Raises:
            RuntimeError: when every loading strategy fails.
        """
        try:
            # Try to load the model normally first
            print("Attempting to load sentence transformer model...")
            model = SentenceTransformer("all-MiniLM-L6-v2")
            print("Successfully loaded sentence transformer model")
            return model
        except Exception as e:
            print(f"Failed to load model online: {e}")
            print("Attempting to load model in offline mode...")
            try:
                # Force HuggingFace libraries to use only locally cached files.
                os.environ['TRANSFORMERS_OFFLINE'] = '1'
                os.environ['HF_HUB_OFFLINE'] = '1'
                model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=None)
                print("Successfully loaded model in offline mode")
                return model
            except Exception as offline_error:
                print(f"Failed to load model in offline mode: {offline_error}")
                # Last resort: point at the legacy transformers cache directory.
                # NOTE(review): newer HF versions cache under
                # ~/.cache/huggingface/hub — confirm this path matches your setup.
                try:
                    cache_dir = os.path.join(os.path.expanduser("~"), ".cache",
                                             "huggingface", "transformers")
                    if os.path.exists(cache_dir):
                        print(f"Looking for cached model in: {cache_dir}")
                        # Try to load from specific cache directory
                        model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=cache_dir)
                        print("Successfully loaded model from cache")
                        return model
                except Exception as cache_error:
                    print(f"Failed to load from cache: {cache_error}")

                # If all else fails, provide instructions
                error_msg = """
Failed to initialize sentence transformer model. This is likely due to network connectivity issues.

Solutions:
1. Check your internet connection
2. If behind a corporate firewall, ensure huggingface.co is accessible
3. Pre-download the model when you have internet access by running:
   python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')"
4. Or manually download the model and place it in your cache directory

For now, the application will not work without the embedding model.
"""
                print(error_msg)
                raise RuntimeError(f"Cannot initialize embedding model: {str(e)}")

    def _create_collection_if_not_exists(self) -> bool:
        """Create the collection with payload indices if it doesn't exist.

        Returns:
            bool: True if the collection exists or was created successfully,
            False if creation failed (the error is logged, not raised).
        """
        try:
            # Check if collection exists
            collections = self.client.get_collections()
            collection_names = [col.name for col in collections.collections]

            if self.collection_name in collection_names:
                print(f"Collection '{self.collection_name}' already exists")
                return True

            print(f"Creating new collection: {self.collection_name}")

            # all-MiniLM-L6-v2 produces 384-dimensional embeddings.
            # (Was hard-coded to 1, which made every upsert fail with a
            # dimension mismatch.)
            vector_size = 384

            # Create collection with vector configuration.
            # NOTE(review): m=0 disables the global HNSW graph while
            # payload_m=16 keeps payload-partitioned links — this is Qdrant's
            # documented multitenancy optimization; confirm it is intended
            # here, since plain (unfiltered) vector search needs m > 0.
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    size=vector_size,
                    distance=models.Distance.COSINE
                ),
                hnsw_config=models.HnswConfigDiff(
                    payload_m=16,
                    m=0,
                ),
            )

            # Create payload indices used by the filter-based search methods.
            payload_indices = {
                "document_id": PayloadSchemaType.KEYWORD,
                "content": PayloadSchemaType.TEXT
            }
            for field_name, schema_type in payload_indices.items():
                self.client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field_name,
                    field_schema=schema_type
                )

            print(f"Successfully created collection: {self.collection_name}")
            return True

        except Exception as e:
            error_msg = f"Failed to create collection {self.collection_name}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            print(error_msg)
            return False

    def add_document(self, text: str, metadata: Optional[Dict] = None) -> bool:
        """Embed *text* and upsert it into the collection.

        Args:
            text: document content to embed and store.
            metadata: optional extra payload fields merged into the point.

        Returns:
            bool: True on success, False on any error (logged to stdout).
        """
        try:
            # Generate embedding
            embedding = self.embedding_model.encode([text])[0]

            # Generate document ID (also used as the Qdrant point id)
            document_id = str(uuid.uuid4())

            # Create payload with indexed fields
            payload = {
                "document_id": document_id,  # KEYWORD index
                "content": text,             # TEXT index - stores the actual text content
            }

            # Add metadata fields if provided
            if metadata:
                payload.update(metadata)

            # Create point
            point = PointStruct(
                id=document_id,
                vector=embedding.tolist(),
                payload=payload
            )

            # Store in Qdrant
            self.client.upsert(
                collection_name=self.collection_name,
                points=[point]
            )
            return True
        except Exception as e:
            print(f"Error adding document: {e}")
            return False

    def search_similar(self, query: str, limit: int = 5) -> List[Dict]:
        """Vector-similarity search for documents close to *query*.

        Args:
            query: free-text query to embed and match.
            limit: maximum number of hits to return.

        Returns:
            List of dicts with "text", "document_id", "score" plus any extra
            payload fields; empty list on error.
        """
        try:
            # Generate query embedding
            query_embedding = self.embedding_model.encode([query])[0]

            # Search in Qdrant.
            # NOTE(review): client.search is deprecated in newer qdrant-client
            # releases in favor of query_points — confirm the pinned version.
            results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding.tolist(),
                limit=limit
            )

            # Return results
            return [
                {
                    "text": hit.payload["content"],  # Use content field
                    "document_id": hit.payload.get("document_id"),
                    "score": hit.score,
                    # Include any additional metadata fields
                    **{k: v for k, v in hit.payload.items()
                       if k not in ["content", "document_id"]}
                }
                for hit in results
            ]
        except Exception as e:
            print(f"Error searching: {e}")
            return []

    def search_by_document_id(self, document_id: str) -> Optional[Dict]:
        """Fetch a single document by its id via the KEYWORD payload index.

        Args:
            document_id: value stored in the "document_id" payload field.

        Returns:
            Dict with "text", "document_id" and extra payload fields, or
            None when not found / on error.
        """
        try:
            # Use scroll with an exact-match filter on the indexed field
            results = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="document_id",
                            match=models.MatchValue(value=document_id)
                        )
                    ]
                ),
                limit=1
            )

            if results[0]:  # results is a tuple (points, next_page_offset)
                hit = results[0][0]  # Get first point
                return {
                    "text": hit.payload["content"],  # Use content field
                    "document_id": hit.payload.get("document_id"),
                    # Include any additional metadata fields
                    **{k: v for k, v in hit.payload.items()
                       if k not in ["content", "document_id"]}
                }
            else:
                return None
        except Exception as e:
            print(f"Error searching by document ID: {e}")
            return None

    def search_by_content(self, content_query: str, limit: int = 5) -> List[Dict]:
        """Full-text search over the "content" field via its TEXT index.

        Args:
            content_query: text matched against stored document content.
            limit: maximum number of documents to return.

        Returns:
            List of dicts with "text", "document_id" and extra payload
            fields; empty list on error.
        """
        try:
            # Use scroll with text search filter
            results = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="content",
                            match=models.MatchText(text=content_query)
                        )
                    ]
                ),
                limit=limit
            )

            # Return results
            return [
                {
                    "text": hit.payload["content"],  # Use content field
                    "document_id": hit.payload.get("document_id"),
                    # Include any additional metadata fields
                    **{k: v for k, v in hit.payload.items()
                       if k not in ["content", "document_id"]}
                }
                for hit in results[0]  # results[0] contains the points
            ]
        except Exception as e:
            print(f"Error searching by content: {e}")
            return []

    def get_collection_info(self) -> Dict:
        """Return summary information about the collection.

        Returns:
            Dict with name, vector size/distance, point count and the
            vectors' on_disk flag; empty dict on error.
        """
        try:
            collection_info = self.client.get_collection(self.collection_name)
            return {
                # CollectionConfig has no .name attribute (the old
                # config.name access always raised), so report the name
                # this instance was constructed with.
                "name": self.collection_name,
                "vector_size": collection_info.config.params.vectors.size,
                "distance": collection_info.config.params.vectors.distance,
                "points_count": collection_info.points_count,
                # NOTE(review): this key reports VectorParams.on_disk, not an
                # "indexed only" flag; kept for caller compatibility.
                "indexed_only": collection_info.config.params.vectors.on_disk
            }
        except Exception as e:
            print(f"Error getting collection info: {e}")
            return {}