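"""Qdrant-backed vector store for document storage and retrieval.

Embeds documents with sentence-transformers (all-MiniLM-L6-v2) and maintains
keyword and full-text payload indices for lookup by document ID and content.
"""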
from qdrant_client import QdrantClient, models
from qdrant_client.models import PointStruct, PayloadSchemaType
from sentence_transformers import SentenceTransformer
import uuid
import os
import logging
from typing import List, Dict, Optional
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Configure logging
logger = logging.getLogger(__name__)
class VectorStore:
    def __init__(self, collection_name: str = "ca_documents"):
        self.collection_name = collection_name

        # Get Qdrant configuration from environment variables
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_api_key = os.getenv("QDRANT_API_KEY")

        if not qdrant_url or not qdrant_api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY environment variables are required")

        # Connect to the Qdrant cluster with the API key
        self.client = QdrantClient(
            url=qdrant_url,
            api_key=qdrant_api_key,
        )
        print("Connected to Qdrant")

        # Initialize embedding model with offline support
        self.embedding_model = self._initialize_embedding_model()

        # Create collection with proper indices
        self._create_collection_if_not_exists()
    def _initialize_embedding_model(self):
        """Initialize the embedding model, falling back to offline/cached loading."""
        try:
            # Try to load the model normally first
            print("Attempting to load sentence transformer model...")
            model = SentenceTransformer("all-MiniLM-L6-v2")
            print("Successfully loaded sentence transformer model")
            return model
        except Exception as e:
            print(f"Failed to load model online: {e}")
            print("Attempting to load model in offline mode...")
            try:
                # Force the Hugging Face libraries to use the local cache only
                os.environ["TRANSFORMERS_OFFLINE"] = "1"
                os.environ["HF_HUB_OFFLINE"] = "1"
                model = SentenceTransformer("all-MiniLM-L6-v2")
                print("Successfully loaded model in offline mode")
                return model
            except Exception as offline_error:
                print(f"Failed to load model in offline mode: {offline_error}")
                # Try the default Hugging Face hub cache directory explicitly
                try:
                    cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub")
                    if os.path.exists(cache_dir):
                        print(f"Looking for cached model in: {cache_dir}")
                        model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=cache_dir)
                        print("Successfully loaded model from cache")
                        return model
                except Exception as cache_error:
                    print(f"Failed to load from cache: {cache_error}")

                # If all else fails, provide instructions
                error_msg = """
                Failed to initialize sentence transformer model. This is likely due to network connectivity issues.

                Solutions:
                1. Check your internet connection
                2. If behind a corporate firewall, ensure huggingface.co is accessible
                3. Pre-download the model when you have internet access by running:
                   python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')"
                4. Or manually download the model and place it in your cache directory

                For now, the application will not work without the embedding model.
                """
                print(error_msg)
                raise RuntimeError(f"Cannot initialize embedding model: {e}")
    def _create_collection_if_not_exists(self) -> bool:
        """
        Create the collection with proper payload indices if it doesn't exist.

        Returns:
            bool: True if the collection exists or was created successfully
        """
        try:
            # Check if the collection exists
            collections = self.client.get_collections()
            collection_names = [col.name for col in collections.collections]

            if self.collection_name in collection_names:
                print(f"Collection '{self.collection_name}' already exists")
                return True

            print(f"Creating new collection: {self.collection_name}")

            # Vector size for all-MiniLM-L6-v2 is 384
            vector_size = 384

            # Create collection with vector configuration
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    size=vector_size,
                    distance=models.Distance.COSINE
                ),
            )

            # Create payload indices
            payload_indices = {
                "document_id": PayloadSchemaType.KEYWORD,
                "content": PayloadSchemaType.TEXT
            }
            for field_name, schema_type in payload_indices.items():
                self.client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field_name,
                    field_schema=schema_type
                )

            print(f"Successfully created collection: {self.collection_name}")
            return True
        except Exception as e:
            error_msg = f"Failed to create collection {self.collection_name}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            print(error_msg)
            return False
    def add_document(self, text: str, metadata: Optional[Dict] = None) -> bool:
        """Add a document to the collection"""
        try:
            # Generate embedding
            embedding = self.embedding_model.encode([text])[0]

            # Generate document ID
            document_id = str(uuid.uuid4())

            # Create payload with indexed fields
            payload = {
                "document_id": document_id,  # KEYWORD index
                "content": text,  # TEXT index - stores the actual text content
            }

            # Add metadata fields if provided
            if metadata:
                payload.update(metadata)

            # Create point
            point = PointStruct(
                id=document_id,
                vector=embedding.tolist(),
                payload=payload
            )

            # Store in Qdrant
            self.client.upsert(
                collection_name=self.collection_name,
                points=[point]
            )
            return True
        except Exception as e:
            print(f"Error adding document: {e}")
            return False
    def search_similar(self, query: str, limit: int = 5) -> List[Dict]:
        """Search for similar documents"""
        try:
            # Generate query embedding
            query_embedding = self.embedding_model.encode([query])[0]

            # Search in Qdrant
            results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding.tolist(),
                limit=limit
            )

            # Return results
            return [
                {
                    "text": hit.payload["content"],  # Use content field
                    "document_id": hit.payload.get("document_id"),
                    "score": hit.score,
                    # Include any additional metadata fields
                    **{k: v for k, v in hit.payload.items() if k not in ["content", "document_id"]}
                }
                for hit in results
            ]
        except Exception as e:
            print(f"Error searching: {e}")
            return []
    def search_by_document_id(self, document_id: str) -> Optional[Dict]:
        """Search for a specific document by its ID using the indexed field"""
        try:
            # Use scroll to find the document by document_id
            results = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="document_id",
                            match=models.MatchValue(value=document_id)
                        )
                    ]
                ),
                limit=1
            )

            if results[0]:  # results is a tuple (points, next_page_offset)
                hit = results[0][0]  # Get the first point
                return {
                    "text": hit.payload["content"],  # Use content field
                    "document_id": hit.payload.get("document_id"),
                    # Include any additional metadata fields
                    **{k: v for k, v in hit.payload.items() if k not in ["content", "document_id"]}
                }
            else:
                return None
        except Exception as e:
            print(f"Error searching by document ID: {e}")
            return None
    def search_by_content(self, content_query: str, limit: int = 5) -> List[Dict]:
        """Search for documents by content using the TEXT index"""
        try:
            # Use scroll with a full-text match filter
            results = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="content",
                            match=models.MatchText(text=content_query)
                        )
                    ]
                ),
                limit=limit
            )

            # Return results
            return [
                {
                    "text": hit.payload["content"],  # Use content field
                    "document_id": hit.payload.get("document_id"),
                    # Include any additional metadata fields
                    **{k: v for k, v in hit.payload.items() if k not in ["content", "document_id"]}
                }
                for hit in results[0]  # results[0] contains the points
            ]
        except Exception as e:
            print(f"Error searching by content: {e}")
            return []
    def get_collection_info(self) -> Dict:
        """Get information about the collection"""
        try:
            collection_info = self.client.get_collection(self.collection_name)
            vectors_config = collection_info.config.params.vectors
            return {
                "name": self.collection_name,
                "vector_size": vectors_config.size,
                "distance": vectors_config.distance,
                "points_count": collection_info.points_count,
                "vectors_on_disk": vectors_config.on_disk
            }
        except Exception as e:
            print(f"Error getting collection info: {e}")
            return {}
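

# A minimal usage sketch, assuming a reachable Qdrant cluster and that
# QDRANT_URL and QDRANT_API_KEY are set in the environment or a .env file.
# The sample text and query strings below are illustrative only.
if __name__ == "__main__":
    store = VectorStore(collection_name="ca_documents")

    # Index a document with optional metadata
    store.add_document(
        "Qdrant is a vector database for similarity search.",
        metadata={"source": "example"},
    )

    # Semantic search over the stored embeddings
    for hit in store.search_similar("What is a vector database?", limit=3):
        print(hit["score"], hit["text"])

    # Full-text search against the TEXT-indexed content field
    for hit in store.search_by_content("vector database"):
        print(hit["document_id"], hit["text"])

    print(store.get_collection_info())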