File size: 11,827 Bytes
deb090d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
import logging
import os
import uuid
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from qdrant_client.models import PayloadSchemaType, PointStruct
from sentence_transformers import SentenceTransformer

# Load environment variables from a local .env file (if present) so that
# QDRANT_URL / QDRANT_API_KEY can be configured without exporting them.
load_dotenv()

# Module-level logger; handlers and level are configured by the application.
logger = logging.getLogger(__name__)

class VectorStore:
    """Qdrant-backed vector store for document storage and similarity search.

    Documents are embedded with the ``all-MiniLM-L6-v2`` sentence-transformer
    model (384-dimensional vectors) and stored in a single Qdrant collection
    that carries payload indices on ``document_id`` (keyword) and ``content``
    (full-text).
    """

    # all-MiniLM-L6-v2 emits 384-dimensional embeddings. The collection's
    # vector size MUST match this, otherwise every upsert is rejected.
    # (The original code created the collection with size 1 — a hard bug.)
    VECTOR_SIZE = 384

    def __init__(self, collection_name: str = "ca_documents"):
        """Connect to Qdrant, load the embedding model, and ensure the collection exists.

        Args:
            collection_name: Name of the Qdrant collection to use/create.

        Raises:
            ValueError: If QDRANT_URL or QDRANT_API_KEY is not set.
            RuntimeError: If the embedding model cannot be initialized.
        """
        self.collection_name = collection_name

        # Qdrant connection settings come from the environment
        # (.env is supported via load_dotenv() at module import time).
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_api_key = os.getenv("QDRANT_API_KEY")

        if not qdrant_url or not qdrant_api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY environment variables are required")

        # Connect to Qdrant cluster with API key
        self.client = QdrantClient(
            url=qdrant_url,
            api_key=qdrant_api_key,
        )
        print("Connected to Qdrant")

        # Initialize embedding model with offline support
        self.embedding_model = self._initialize_embedding_model()

        # Create collection with proper indices. NOTE(review): failures here
        # are reported via the bool return / logs but not raised, so the
        # instance can exist with a missing collection — confirm intended.
        self._create_collection_if_not_exists()

    def _initialize_embedding_model(self):
        """Load the sentence-transformer model, falling back to offline caches.

        Tries, in order: a normal (possibly network-backed) load, an
        offline-mode load, and an explicit load from the local HF cache
        directory. Raises RuntimeError if all attempts fail.
        """
        try:
            # Try to load the model normally first (may hit the network).
            print("Attempting to load sentence transformer model...")
            model = SentenceTransformer("all-MiniLM-L6-v2")
            print("Successfully loaded sentence transformer model")
            return model

        except Exception as e:
            print(f"Failed to load model online: {e}")
            print("Attempting to load model in offline mode...")

            try:
                # Force huggingface libraries to use only local caches.
                os.environ['TRANSFORMERS_OFFLINE'] = '1'
                os.environ['HF_HUB_OFFLINE'] = '1'

                model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=None)
                print("Successfully loaded model in offline mode")
                return model

            except Exception as offline_error:
                print(f"Failed to load model in offline mode: {offline_error}")

                # Last resort: point the loader at the default HF cache path.
                # NOTE(review): newer huggingface versions cache under
                # ~/.cache/huggingface/hub rather than .../transformers —
                # confirm this path against the installed library version.
                try:
                    cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "transformers")
                    if os.path.exists(cache_dir):
                        print(f"Looking for cached model in: {cache_dir}")

                        # Try to load from specific cache directory
                        model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=cache_dir)
                        print("Successfully loaded model from cache")
                        return model

                except Exception as cache_error:
                    print(f"Failed to load from cache: {cache_error}")

                # All attempts failed — print remediation steps and raise.
                error_msg = """
                Failed to initialize sentence transformer model. This is likely due to network connectivity issues.
                
                Solutions:
                1. Check your internet connection
                2. If behind a corporate firewall, ensure huggingface.co is accessible
                3. Pre-download the model when you have internet access by running:
                   python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')"
                4. Or manually download the model and place it in your cache directory
                
                For now, the application will not work without the embedding model.
                """

                print(error_msg)
                raise RuntimeError(f"Cannot initialize embedding model: {str(e)}")

    def _create_collection_if_not_exists(self) -> bool:
        """Create the collection and its payload indices if it doesn't exist.

        Returns:
            bool: True if the collection exists or was created successfully,
                False if creation failed (the error is logged and printed).
        """
        try:
            # Check whether the collection is already present on the cluster.
            collections = self.client.get_collections()
            collection_names = [col.name for col in collections.collections]

            if self.collection_name in collection_names:
                print(f"Collection '{self.collection_name}' already exists")
                return True

            print(f"Creating new collection: {self.collection_name}")

            # BUG FIX: was hard-coded to 1, which contradicts the embedding
            # model's 384-dim output and would make every upsert fail.
            vector_size = self.VECTOR_SIZE

            # Create collection with vector configuration.
            # NOTE(review): m=0 disables the global HNSW graph; combined with
            # payload_m=16 this is Qdrant's payload-partitioned-index setup,
            # which assumes searches filter on an indexed keyword field —
            # confirm this matches the intended query patterns.
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    size=vector_size,
                    distance=models.Distance.COSINE
                ),
                hnsw_config=models.HnswConfigDiff(
                    payload_m=16,
                    m=0,
                ),
            )

            # Payload indices: keyword lookup by document_id, full-text
            # matching on content.
            payload_indices = {
                "document_id": PayloadSchemaType.KEYWORD,
                "content": PayloadSchemaType.TEXT
            }

            for field_name, schema_type in payload_indices.items():
                self.client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field_name,
                    field_schema=schema_type
                )

            print(f"Successfully created collection: {self.collection_name}")
            return True

        except Exception as e:
            error_msg = f"Failed to create collection {self.collection_name}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            print(error_msg)
            return False

    def add_document(self, text: str, metadata: Optional[Dict] = None) -> bool:
        """Embed *text* and store it as a new point in the collection.

        Args:
            text: Document content to embed and store.
            metadata: Optional extra payload fields. NOTE(review): keys named
                "document_id" or "content" would overwrite the built-in
                fields — confirm callers never pass those.

        Returns:
            bool: True on success, False if embedding or upsert failed.
        """
        try:
            # Generate embedding (encode returns a batch; take the single row).
            embedding = self.embedding_model.encode([text])[0]

            # Generate document ID; also used as the Qdrant point ID.
            document_id = str(uuid.uuid4())

            # Create payload with indexed fields
            payload = {
                "document_id": document_id,  # KEYWORD index
                "content": text,             # TEXT index - stores the actual text content
            }

            # Add metadata fields if provided
            if metadata:
                payload.update(metadata)

            # Create point
            point = PointStruct(
                id=document_id,
                vector=embedding.tolist(),
                payload=payload
            )

            # Store in Qdrant
            self.client.upsert(
                collection_name=self.collection_name,
                points=[point]
            )

            return True
        except Exception as e:
            print(f"Error adding document: {e}")
            return False

    def search_similar(self, query: str, limit: int = 5) -> List[Dict]:
        """Return up to *limit* documents most similar to *query* by vector distance.

        Each result dict carries "text", "document_id", "score", plus any
        extra metadata fields stored on the point. Returns [] on error.
        """
        try:
            # Generate query embedding
            query_embedding = self.embedding_model.encode([query])[0]

            # Search in Qdrant
            results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding.tolist(),
                limit=limit
            )

            # Flatten each hit's payload into a result dict.
            return [
                {
                    "text": hit.payload["content"],  # Use content field
                    "document_id": hit.payload.get("document_id"),
                    "score": hit.score,
                    # Include any additional metadata fields
                    **{k: v for k, v in hit.payload.items() if k not in ["content", "document_id"]}
                }
                for hit in results
            ]

        except Exception as e:
            print(f"Error searching: {e}")
            return []

    def search_by_document_id(self, document_id: str) -> Optional[Dict]:
        """Fetch a single document by its ID via the keyword payload index.

        Returns:
            The document dict ("text", "document_id", extra metadata), or
            None when not found or on error.
        """
        try:
            # scroll() returns (points, next_page_offset); filter on the
            # indexed document_id field and take at most one point.
            results = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="document_id",
                            match=models.MatchValue(value=document_id)
                        )
                    ]
                ),
                limit=1
            )

            if results[0]:  # results is a tuple (points, next_page_offset)
                hit = results[0][0]  # Get first point
                return {
                    "text": hit.payload["content"],  # Use content field
                    "document_id": hit.payload.get("document_id"),
                    # Include any additional metadata fields
                    **{k: v for k, v in hit.payload.items() if k not in ["content", "document_id"]}
                }
            else:
                return None

        except Exception as e:
            print(f"Error searching by document ID: {e}")
            return None

    def search_by_content(self, content_query: str, limit: int = 5) -> List[Dict]:
        """Full-text search on the "content" payload field (no vector scoring).

        Returns up to *limit* matching documents as dicts ("text",
        "document_id", extra metadata), or [] on error.
        """
        try:
            # Use scroll with a MatchText condition against the TEXT index.
            results = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="content",
                            match=models.MatchText(text=content_query)
                        )
                    ]
                ),
                limit=limit
            )

            # Return results
            return [
                {
                    "text": hit.payload["content"],  # Use content field
                    "document_id": hit.payload.get("document_id"),
                    # Include any additional metadata fields
                    **{k: v for k, v in hit.payload.items() if k not in ["content", "document_id"]}
                }
                for hit in results[0]  # results[0] contains the points
            ]

        except Exception as e:
            print(f"Error searching by content: {e}")
            return []

    def get_collection_info(self) -> Dict:
        """Return summary info (name, vector size, distance, point count) for the collection.

        Returns {} on error. NOTE(review): the attribute paths used here
        (config.name, config.params.vectors.*) vary across qdrant-client
        versions — confirm against the pinned client version.
        """
        try:
            collection_info = self.client.get_collection(self.collection_name)
            return {
                "name": collection_info.config.name,
                "vector_size": collection_info.config.params.vectors.size,
                "distance": collection_info.config.params.vectors.distance,
                "points_count": collection_info.points_count,
                "indexed_only": collection_info.config.params.vectors.on_disk
            }
        except Exception as e:
            print(f"Error getting collection info: {e}")
            return {}