Spaces:
Sleeping
Sleeping
from typing import Any | |
from llama_index.schema import BaseNode, MetadataMode | |
from llama_index.vector_stores import ChromaVectorStore | |
from llama_index.vector_stores.chroma import chunk_list | |
from llama_index.vector_stores.utils import node_to_metadata_dict | |
class BatchedChromaVectorStore(ChromaVectorStore):
    """ChromaDB-backed vector store whose ``add`` inserts nodes in batches.

    Embeddings are kept in a ChromaDB collection, and top-k similarity
    queries are answered by ChromaDB. Unlike a single bulk insert, ``add``
    splits the incoming nodes into chunks no larger than the client's
    ``max_batch_size``, so no individual ``collection.add`` call exceeds it.

    Args:
        chroma_client: ChromaDB client (``chromadb.api.API``); its
            ``max_batch_size`` sets the size of each insert batch.
        chroma_collection: ChromaDB collection
            (``chromadb.api.models.Collection.Collection``).
        host: Forwarded unchanged to ``ChromaVectorStore``.
        port: Forwarded unchanged to ``ChromaVectorStore``.
        ssl: Forwarded unchanged to ``ChromaVectorStore``.
        headers: Forwarded unchanged to ``ChromaVectorStore``.
        collection_kwargs: Forwarded to ``ChromaVectorStore``; a falsy
            value (e.g. ``None``) is replaced by an empty dict.
    """

    # Client handle stored by __init__; add() checks it and reads its
    # max_batch_size.
    chroma_client: Any | None

    def __init__(
        self,
        chroma_client: Any,
        chroma_collection: Any,
        host: str | None = None,
        port: str | None = None,
        ssl: bool = False,
        headers: dict[str, str] | None = None,
        collection_kwargs: dict[Any, Any] | None = None,
    ) -> None:
        # All connection/collection settings are forwarded to the parent;
        # only the client is kept by this subclass.
        base_kwargs: dict[str, Any] = {
            "chroma_collection": chroma_collection,
            "host": host,
            "port": port,
            "ssl": ssl,
            "headers": headers,
            "collection_kwargs": collection_kwargs or {},
        }
        super().__init__(**base_kwargs)
        self.chroma_client = chroma_client

    def add(self, nodes: list[BaseNode], **add_kwargs: Any) -> list[str]:
        """Insert ``nodes`` into the collection, one client-sized batch at a time.

        Args:
            nodes: Nodes to store; each must already carry an embedding
                (read via ``node.get_embedding()``).
            add_kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The ids of all inserted nodes, in insertion order.

        Raises:
            ValueError: If the client or the collection is not set (falsy).
        """
        client = self.chroma_client
        if not client:
            raise ValueError("Client not initialized")

        collection = self._collection
        if not collection:
            raise ValueError("Collection not initialized")

        inserted_ids: list[str] = []
        for batch in chunk_list(nodes, client.max_batch_size):
            # Column-oriented payload keyed by collection.add's keyword names.
            columns: dict[str, list[Any]] = {
                "embeddings": [],
                "ids": [],
                "metadatas": [],
                "documents": [],
            }
            # Fill the columns node by node so every per-node call happens
            # in the same order as before (embedding, metadata, id, text).
            for node in batch:
                columns["embeddings"].append(node.get_embedding())
                # remove_text=True keeps the node text out of the metadata;
                # the text is sent separately in "documents" below.
                columns["metadatas"].append(
                    node_to_metadata_dict(
                        node, remove_text=True, flat_metadata=self.flat_metadata
                    )
                )
                columns["ids"].append(node.node_id)
                columns["documents"].append(
                    node.get_content(metadata_mode=MetadataMode.NONE)
                )
            collection.add(**columns)
            inserted_ids.extend(columns["ids"])

        return inserted_ids