from qdrant_client import QdrantClient, models
from qdrant_client.models import PointStruct, PayloadSchemaType
from sentence_transformers import SentenceTransformer
import uuid
import os
import logging
from typing import List, Dict, Optional
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
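
# Expected .env keys (illustrative placeholders, not real credentials):
#   QDRANT_URL=https://<cluster-id>.<region>.cloud.qdrant.io:6333
#   QDRANT_API_KEY=<your-api-key>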
# Configure logging
logger = logging.getLogger(__name__)
class VectorStore:
    def __init__(self, collection_name: str = "ca_documents"):
        self.collection_name = collection_name

        # Get Qdrant configuration from environment variables
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_api_key = os.getenv("QDRANT_API_KEY")
        if not qdrant_url or not qdrant_api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY environment variables are required")

        # Connect to the Qdrant cluster with the API key
        self.client = QdrantClient(
            url=qdrant_url,
            api_key=qdrant_api_key,
        )
        print("Connected to Qdrant")

        # Initialize the embedding model, with offline fallbacks
        self.embedding_model = self._initialize_embedding_model()

        # Create the collection and payload indices if needed
        self._create_collection_if_not_exists()
    def _initialize_embedding_model(self):
        """Initialize the embedding model, falling back to offline/cached loading."""
        try:
            # Try to load the model normally first
            print("Attempting to load sentence transformer model...")
            model = SentenceTransformer("all-MiniLM-L6-v2")
            print("Successfully loaded sentence transformer model")
            return model
        except Exception as e:
            print(f"Failed to load model online: {e}")
            print("Attempting to load model in offline mode...")
            try:
                # Force the Hugging Face libraries to use only the local cache
                os.environ["TRANSFORMERS_OFFLINE"] = "1"
                os.environ["HF_HUB_OFFLINE"] = "1"
                model = SentenceTransformer("all-MiniLM-L6-v2")
                print("Successfully loaded model in offline mode")
                return model
            except Exception as offline_error:
                print(f"Failed to load model in offline mode: {offline_error}")

            # Try to load from the default local cache directory
            try:
                cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "transformers")
                if os.path.exists(cache_dir):
                    print(f"Looking for cached model in: {cache_dir}")
                    model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=cache_dir)
                    print("Successfully loaded model from cache")
                    return model
            except Exception as cache_error:
                print(f"Failed to load from cache: {cache_error}")

            # If all else fails, provide instructions
            error_msg = """
            Failed to initialize sentence transformer model. This is likely due to network connectivity issues.

            Solutions:
            1. Check your internet connection
            2. If behind a corporate firewall, ensure huggingface.co is accessible
            3. Pre-download the model when you have internet access by running:
               python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')"
            4. Or manually download the model and place it in your cache directory

            For now, the application will not work without the embedding model.
            """
            print(error_msg)
            raise RuntimeError(f"Cannot initialize embedding model: {e}")
    def _create_collection_if_not_exists(self) -> bool:
        """
        Create the collection with proper payload indices if it doesn't exist.

        Returns:
            bool: True if the collection exists or was created successfully
        """
        try:
            # Check whether the collection already exists
            collections = self.client.get_collections()
            collection_names = [col.name for col in collections.collections]
            if self.collection_name in collection_names:
                print(f"Collection '{self.collection_name}' already exists")
                return True

            print(f"Creating new collection: {self.collection_name}")

            # Vector size for all-MiniLM-L6-v2 is 384
            vector_size = 384

            # Create the collection with a cosine-distance vector index.
            # The default HNSW config is kept because search_similar() runs
            # unfiltered (global) searches; disabling the global graph with
            # m=0 would force a full scan for those queries.
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    size=vector_size,
                    distance=models.Distance.COSINE
                ),
            )

            # Create payload indices: KEYWORD for exact document_id lookups,
            # TEXT for full-text matching on the stored content
            payload_indices = {
                "document_id": PayloadSchemaType.KEYWORD,
                "content": PayloadSchemaType.TEXT
            }
            for field_name, schema_type in payload_indices.items():
                self.client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field_name,
                    field_schema=schema_type
                )

            print(f"Successfully created collection: {self.collection_name}")
            return True
        except Exception as e:
            error_msg = f"Failed to create collection {self.collection_name}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            print(error_msg)
            return False
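    # Note: the hnsw_config in the original version (m=0, payload_m=16) is
    # Qdrant's multitenancy pattern: it disables the global HNSW graph and
    # builds per-payload-value graphs instead, which only pays off when every
    # vector search is filtered by an indexed payload field. A sketch of that
    # variant, should per-tenant search ever be needed:
    #
    #     self.client.create_collection(
    #         collection_name=self.collection_name,
    #         vectors_config=models.VectorParams(
    #             size=vector_size, distance=models.Distance.COSINE
    #         ),
    #         hnsw_config=models.HnswConfigDiff(m=0, payload_m=16),
    #     )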
    def add_document(self, text: str, metadata: Optional[Dict] = None) -> bool:
        """Add a single document to the collection."""
        try:
            # Generate the embedding for the document text
            embedding = self.embedding_model.encode([text])[0]

            # Generate a unique document ID
            document_id = str(uuid.uuid4())

            # Build the payload with the indexed fields
            payload = {
                "document_id": document_id,  # KEYWORD index
                "content": text,             # TEXT index - stores the actual text content
            }

            # Add metadata fields if provided
            if metadata:
                payload.update(metadata)

            # Create the point and store it in Qdrant
            point = PointStruct(
                id=document_id,
                vector=embedding.tolist(),
                payload=payload
            )
            self.client.upsert(
                collection_name=self.collection_name,
                points=[point]
            )
            return True
        except Exception as e:
            print(f"Error adding document: {e}")
            return False
    def search_similar(self, query: str, limit: int = 5) -> List[Dict]:
        """Search for documents semantically similar to the query."""
        try:
            # Generate the query embedding
            query_embedding = self.embedding_model.encode([query])[0]

            # Vector search in Qdrant
            results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding.tolist(),
                limit=limit
            )

            return [
                {
                    "text": hit.payload["content"],  # Use content field
                    "document_id": hit.payload.get("document_id"),
                    "score": hit.score,
                    # Include any additional metadata fields
                    **{k: v for k, v in hit.payload.items() if k not in ["content", "document_id"]}
                }
                for hit in results
            ]
        except Exception as e:
            print(f"Error searching: {e}")
            return []
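    # Note: newer qdrant-client releases (>= 1.10) deprecate client.search()
    # in favor of the unified query_points() API. Assuming such a version,
    # the search call above would become roughly:
    #
    #     results = self.client.query_points(
    #         collection_name=self.collection_name,
    #         query=query_embedding.tolist(),
    #         limit=limit,
    #     ).points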
    def search_by_document_id(self, document_id: str) -> Optional[Dict]:
        """Look up a specific document by its ID using the KEYWORD index."""
        try:
            # Use scroll with an exact-match filter on document_id
            results = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="document_id",
                            match=models.MatchValue(value=document_id)
                        )
                    ]
                ),
                limit=1
            )

            # scroll() returns a tuple (points, next_page_offset)
            if results[0]:
                hit = results[0][0]  # First (and only) matching point
                return {
                    "text": hit.payload["content"],  # Use content field
                    "document_id": hit.payload.get("document_id"),
                    # Include any additional metadata fields
                    **{k: v for k, v in hit.payload.items() if k not in ["content", "document_id"]}
                }
            return None
        except Exception as e:
            print(f"Error searching by document ID: {e}")
            return None
    def search_by_content(self, content_query: str, limit: int = 5) -> List[Dict]:
        """Search for documents by full-text match using the TEXT index."""
        try:
            # Use scroll with a full-text filter on the content field
            results = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="content",
                            match=models.MatchText(text=content_query)
                        )
                    ]
                ),
                limit=limit
            )

            return [
                {
                    "text": hit.payload["content"],  # Use content field
                    "document_id": hit.payload.get("document_id"),
                    # Include any additional metadata fields
                    **{k: v for k, v in hit.payload.items() if k not in ["content", "document_id"]}
                }
                for hit in results[0]  # results[0] contains the points
            ]
        except Exception as e:
            print(f"Error searching by content: {e}")
            return []
    def get_collection_info(self) -> Dict:
        """Get summary information about the collection."""
        try:
            collection_info = self.client.get_collection(self.collection_name)
            # get_collection() returns a CollectionInfo whose config carries
            # no collection name, so report the name this instance holds
            return {
                "name": self.collection_name,
                "vector_size": collection_info.config.params.vectors.size,
                "distance": collection_info.config.params.vectors.distance,
                "points_count": collection_info.points_count,
                "on_disk": collection_info.config.params.vectors.on_disk
            }
        except Exception as e:
            print(f"Error getting collection info: {e}")
            return {}
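

# Minimal usage sketch (assumes QDRANT_URL and QDRANT_API_KEY are set in the
# environment or a .env file; the document text and metadata are illustrative).
if __name__ == "__main__":
    store = VectorStore(collection_name="ca_documents")

    # Index a document with some optional metadata
    store.add_document(
        "Depreciation is the systematic allocation of an asset's cost over its useful life.",
        metadata={"source": "notes", "topic": "accounting"},
    )

    # Semantic search over the stored vectors
    for hit in store.search_similar("How is an asset's cost spread over time?", limit=3):
        print(f"{hit['score']:.3f}  {hit['text'][:80]}")

    # Collection stats
    print(store.get_collection_info())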