Hussam committed on
Commit
06d7b2d
·
1 Parent(s): 93f6882

revised mongo_db index creation, storing, and similarity search using the new Chunk model

Browse files
src/ctp_slack_bot/db/mongo_db.py CHANGED
@@ -83,16 +83,61 @@ class MongoDB(BaseModel):
83
  return False
84
 
85
  async def get_collection(self: Self, name: str) -> Any:
86
- """Get a collection by name with validation."""
 
 
 
87
  if not await self.ping():
88
  raise ConnectionError("MongoDB connection is not available")
 
 
 
 
 
 
 
 
89
  return self.db[name]
90
 
91
- async def create_indexes(self: Self, collection_name: str, indexes: list) -> None:
92
- """Create indexes on a collection."""
 
 
 
 
93
  collection = await self.get_collection(collection_name)
94
- await collection.create_indexes(indexes)
95
- logger.info("Created indexes for collection {}: {}", collection_name, indexes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  async def close(self: Self) -> None:
98
  """Close MongoDB connection."""
 
83
  return False
84
 
85
  async def get_collection(self: Self, name: str) -> Any:
86
+ """
87
+ Get a collection by name with validation.
88
+ Creates the collection if it doesn't exist.
89
+ """
90
  if not await self.ping():
91
  raise ConnectionError("MongoDB connection is not available")
92
+
93
+ # Get all collection names to check if this one exists
94
+ collection_names = await self.db.list_collection_names()
95
+ if name not in collection_names:
96
+ logger.info(f"Collection {name} does not exist. Creating it.")
97
+ # Create the collection
98
+ await self.db.create_collection(name)
99
+
100
  return self.db[name]
101
 
102
+ async def create_indexes(self: Self, collection_name: str, indexes: list = None) -> None:
103
+ """
104
+ Create indexes on a collection.
105
+ If no indexes provided and collection needs vector search capability,
106
+ creates a vector search index using config settings.
107
+ """
108
  collection = await self.get_collection(collection_name)
109
+
110
+ if indexes:
111
+ await collection.create_indexes(indexes)
112
+ logger.info("Created custom indexes for collection {}: {}", collection_name, indexes)
113
+
114
+ else: # Create vector search index using settings from config
115
+ try:
116
+ # Create the vector search index with the proper MongoDB format
117
+ vector_search_index = {
118
+ "mappings": {
119
+ "dynamic": True,
120
+ "fields": {
121
+ "embedding": {
122
+ "type": "knnVector",
123
+ "dimensions": self.settings.VECTOR_DIMENSION,
124
+ "similarity": "cosine"
125
+ }
126
+ }
127
+ }
128
+ }
129
+
130
+ # Using createSearchIndex command which is the proper way to create vector search indexes
131
+ await self.db.command({
132
+ "createSearchIndex": collection_name,
133
+ "name": f"{collection_name}_vector_index",
134
+ "definition": vector_search_index
135
+ })
136
+
137
+ logger.info("Created vector search index for collection {}", collection_name)
138
+ except Exception as e:
139
+ logger.error("Failed to create vector index: {}", e)
140
+ raise
141
 
142
  async def close(self: Self) -> None:
143
  """Close MongoDB connection."""
src/ctp_slack_bot/services/context_retrieval_service.py CHANGED
@@ -22,7 +22,7 @@ class ContextRetrievalService(BaseModel):
22
  logger.debug("Created {}", self.__class__.__name__)
23
  return self
24
 
25
- async def get_context(self, message: SlackMessage) -> Sequence[Chunk]:
26
  """
27
  Retrieve relevant context for a given SlackMessage by vectorizing the message and
28
  querying the vectorstore.
 
22
  logger.debug("Created {}", self.__class__.__name__)
23
  return self
24
 
25
+ async def get_context(self: Self, message: SlackMessage) -> Sequence[Chunk]:
26
  """
27
  Retrieve relevant context for a given SlackMessage by vectorizing the message and
28
  querying the vectorstore.
src/ctp_slack_bot/services/vector_database_service.py CHANGED
@@ -4,13 +4,12 @@ from typing import Any, Collection, Dict, List, Optional, Self, Sequence
4
 
5
  from ctp_slack_bot.core import Settings
6
  from ctp_slack_bot.db import MongoDB
7
- from ctp_slack_bot.models import Chunk, Content, VectorizedChunk, VectorQuery
8
 
9
  class VectorDatabaseService(BaseModel): # TODO: this should not rely specifically on MongoDB.
10
  """
11
  Service for storing and retrieving vector embeddings from MongoDB.
12
  """
13
-
14
  settings: Settings
15
  mongo_db: MongoDB
16
 
@@ -18,69 +17,82 @@ class VectorDatabaseService(BaseModel): # TODO: this should not rely specificall
18
  def post_init(self: Self) -> Self:
19
  logger.debug("Created {}", self.__class__.__name__)
20
  return self
21
-
22
- # Should not allow initialization calls to bubble up all the way to the surface ― sequester in `post_init` or the class on which it depends.
23
- # async def initialize(self) -> None:
24
- # """
25
- # Initialize the database connection.
26
- # """
27
- # await self.mongo_db.initialize()
28
 
29
  # TODO: Weigh the cost of going all async.
30
- async def store(self, chunks: Collection[VectorizedChunk]) -> None:
31
  """
32
- Store text and its embedding vector in the database.
33
 
34
  Args:
35
- text: The text content to store
36
- embedding: The vector embedding of the text
37
- metadata: Additional metadata about the text (source, timestamp, etc.)
38
 
39
- Returns:
40
- str: The ID of the stored document
41
  """
42
- if not self.mongo_db.initialized:
43
- await self.mongo_db.initialize()
 
44
 
45
  try:
46
- # Create document to store
47
- document = {
48
- "text": text,
49
- "embedding": embedding,
50
- "metadata": metadata
51
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- # Insert into collection
54
- result = await self.mongo_db.vector_collection.insert_one(document)
55
- logger.debug(f"Stored document with ID: {result.inserted_id}")
56
 
57
- return str(result.inserted_id)
58
  except Exception as e:
59
- logger.error(f"Error storing embedding: {str(e)}")
60
  raise
61
 
62
- async def search_by_similarity(self, query: VectorQuery) -> Sequence[Chunk]:
 
 
 
 
 
 
 
 
 
63
  """
64
  Query the vector database for similar documents.
65
 
66
  Args:
67
  query: VectorQuery object with search parameters
68
- query_embedding: The vector embedding of the query text
69
-
70
  Returns:
71
- List[RetreivedContext]: List of similar documents with similarity scores
72
  """
73
- if not self.mongo_db.initialized:
74
- await self.mongo_db.initialize()
75
-
76
  try:
 
 
 
77
  # Build aggregation pipeline for vector search
78
  pipeline = [
79
  {
80
  "$search": {
81
- "index": "vector_index",
82
  "knnBeta": {
83
- "vector": query_embedding,
84
  "path": "embedding",
85
  "k": query.k
86
  }
@@ -88,10 +100,11 @@ class VectorDatabaseService(BaseModel): # TODO: this should not rely specificall
88
  },
89
  {
90
  "$project": {
91
- "_id": 0,
92
  "text": 1,
93
  "metadata": 1,
94
- "score": {"$meta": "searchScore"}
 
 
95
  }
96
  }
97
  ]
@@ -101,31 +114,33 @@ class VectorDatabaseService(BaseModel): # TODO: this should not rely specificall
101
  metadata_filter = {f"metadata.{k}": v for k, v in query.filter_metadata.items()}
102
  pipeline.insert(1, {"$match": metadata_filter})
103
 
 
 
 
 
 
 
 
 
104
  # Execute the pipeline
105
- results = await self.mongo_db.vector_collection.aggregate(pipeline).to_list(length=query.k)
106
 
107
- # Convert to RetreivedContext objects directly
108
- context_results = []
109
  for result in results:
110
- # Normalize score to [0,1] range
111
- normalized_score = result.get("score", 0)
112
-
113
- # Skip if below threshold
114
- if normalized_score < query.score_threshold:
115
- continue
116
-
117
- context_results.append(
118
- Content(
119
- contextual_text=result["text"],
120
- metadata_source=result["metadata"].get("source", "unknown"),
121
- similarity_score=normalized_score,
122
- said_by=result["metadata"].get("speaker", None),
123
- in_reation_to_question=result["metadata"].get("related_question", None)
124
- )
125
  )
 
126
 
127
- logger.debug(f"Found {len(context_results)} similar documents")
128
- return context_results
129
 
130
  except Exception as e:
131
  logger.error(f"Error in similarity search: {str(e)}")
 
4
 
5
  from ctp_slack_bot.core import Settings
6
  from ctp_slack_bot.db import MongoDB
7
+ from ctp_slack_bot.models import Chunk, VectorizedChunk, VectorQuery
8
 
9
  class VectorDatabaseService(BaseModel): # TODO: this should not rely specifically on MongoDB.
10
  """
11
  Service for storing and retrieving vector embeddings from MongoDB.
12
  """
 
13
  settings: Settings
14
  mongo_db: MongoDB
15
 
 
17
  def post_init(self: Self) -> Self:
18
  logger.debug("Created {}", self.__class__.__name__)
19
  return self
 
 
 
 
 
 
 
20
 
21
  # TODO: Weigh the cost of going all async.
22
+ async def store(self: Self, chunks: Collection[VectorizedChunk]) -> None:
23
  """
24
+ Stores vectorized chunks and their embedding vectors in the database.
25
 
26
  Args:
27
+ chunks: Collection of VectorizedChunk objects to store
 
 
28
 
29
+ Returns: None
 
30
  """
31
+ if not chunks:
32
+ logger.debug("No chunks to store")
33
+ return
34
 
35
  try:
36
+ # Get the vector collection - this will create it if it doesn't exist
37
+ vector_collection = await self.mongo_db.get_collection("vectors")
38
+
39
+ # Ensure vector search index exists
40
+ await self.mongo_db.create_indexes("vectors")
41
+
42
+ # Create documents to store, ensuring compatibility with BSON
43
+ documents = []
44
+ for chunk in chunks:
45
+ # Convert embedding to standard list format (important for BSON compatibility)
46
+ embedding = list(chunk.embedding) if not isinstance(chunk.embedding, list) else chunk.embedding
47
+
48
+ # Build document with proper structure
49
+ document = {
50
+ "text": chunk.text,
51
+ "embedding": embedding,
52
+ "metadata": chunk.metadata,
53
+ "parent_id": chunk.parent_id,
54
+ "chunk_id": chunk.chunk_id
55
+ }
56
+ documents.append(document)
57
 
58
+ # Insert into collection as a batch
59
+ result = await vector_collection.insert_many(documents)
60
+ logger.info(f"Stored {len(result.inserted_ids)} vector chunks in database")
61
 
 
62
  except Exception as e:
63
+ logger.error(f"Error storing vector embeddings: {str(e)}")
64
  raise
65
 
66
+ async def content_exists(self: Self, key: str)-> bool: # TODO: implement this.
67
+ """
68
+ Check if content exists in the database.
69
+
70
+ Args:
71
+ key: The key to check for content existence
72
+ """
73
+ pass
74
+
75
+ async def search_by_similarity(self: Self, query: VectorQuery) -> Sequence[Chunk]:
76
  """
77
  Query the vector database for similar documents.
78
 
79
  Args:
80
  query: VectorQuery object with search parameters
81
+
 
82
  Returns:
83
+ Sequence[Chunk]: List of similar chunks
84
  """
 
 
 
85
  try:
86
+ # Get the vector collection
87
+ vector_collection = await self.mongo_db.get_collection("vectors")
88
+
89
  # Build aggregation pipeline for vector search
90
  pipeline = [
91
  {
92
  "$search": {
93
+ "index": "vectors_vector_index",
94
  "knnBeta": {
95
+ "vector": list(query.query_embeddings),
96
  "path": "embedding",
97
  "k": query.k
98
  }
 
100
  },
101
  {
102
  "$project": {
 
103
  "text": 1,
104
  "metadata": 1,
105
+ "parent_id": 1,
106
+ "chunk_id": 1,
107
+ "score": { "$meta": "searchScore" }
108
  }
109
  }
110
  ]
 
114
  metadata_filter = {f"metadata.{k}": v for k, v in query.filter_metadata.items()}
115
  pipeline.insert(1, {"$match": metadata_filter})
116
 
117
+ # Add score threshold filter
118
+ if query.score_threshold > 0:
119
+ pipeline.append({
120
+ "$match": {
121
+ "score": { "$gte": query.score_threshold }
122
+ }
123
+ })
124
+
125
  # Execute the pipeline
126
+ results = await vector_collection.aggregate(pipeline).to_list(length=query.k)
127
 
128
+ # Convert results to Chunk objects
129
+ chunks = []
130
  for result in results:
131
+ chunk = Chunk(
132
+ text=result["text"],
133
+ parent_id=result["parent_id"],
134
+ chunk_id=result["chunk_id"],
135
+ metadata={
136
+ **result["metadata"],
137
+ "similarity_score": result.get("score", 0)
138
+ }
 
 
 
 
 
 
 
139
  )
140
+ chunks.append(chunk)
141
 
142
+ logger.info(f"Found {len(chunks)} similar chunks with similarity search")
143
+ return chunks
144
 
145
  except Exception as e:
146
  logger.error(f"Error in similarity search: {str(e)}")