from typing import Any

from llama_index.schema import BaseNode, MetadataMode
from llama_index.vector_stores import ChromaVectorStore
from llama_index.vector_stores.chroma import chunk_list
from llama_index.vector_stores.utils import node_to_metadata_dict


class BatchedChromaVectorStore(ChromaVectorStore):
    """Chroma vector store that inserts nodes in size-limited batches.

    Embeddings are written to a ChromaDB collection. Insertions are split
    into chunks no larger than the client's ``max_batch_size`` so that a
    single ``add`` call never exceeds the maximum batch limit.

    Args:
        chroma_client: ChromaDB API instance; its ``max_batch_size`` is read
            on every call to ``add``.
        chroma_collection: ChromaDB collection instance that receives the
            nodes.

    The remaining constructor arguments (``host``, ``port``, ``ssl``,
    ``headers``, ``collection_kwargs``) are forwarded to
    ``ChromaVectorStore`` as-is, except that a falsy ``collection_kwargs``
    is replaced with an empty dict.
    """

    chroma_client: Any | None

    def __init__(
        self,
        chroma_client: Any,
        chroma_collection: Any,
        host: str | None = None,
        port: str | None = None,
        ssl: bool = False,
        headers: dict[str, str] | None = None,
        collection_kwargs: dict[Any, Any] | None = None,
    ) -> None:
        super().__init__(
            chroma_collection=chroma_collection,
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            collection_kwargs=collection_kwargs or {},
        )
        # Kept on the instance because ``add`` asks the client for its batch
        # size limit.
        self.chroma_client = chroma_client

    def add(self, nodes: list[BaseNode], **add_kwargs: Any) -> list[str]:
        """Insert nodes into the collection, one size-limited batch at a time.

        Args:
            nodes: Nodes to insert; each one must be able to provide an
                embedding via ``get_embedding()``.
            add_kwargs: Accepted for interface compatibility; not used by
                this method.

        Returns:
            The ids of all inserted nodes, in insertion order.

        Raises:
            ValueError: If the client or the collection is not set.
        """
        if not self.chroma_client:
            raise ValueError("Client not initialized")
        if not self._collection:
            raise ValueError("Collection not initialized")

        batch_limit = self.chroma_client.max_batch_size
        inserted_ids = []

        for batch in chunk_list(nodes, batch_limit):
            # One record per node: (embedding, metadata, id, document text).
            records = [
                (
                    node.get_embedding(),
                    node_to_metadata_dict(
                        node, remove_text=True, flat_metadata=self.flat_metadata
                    ),
                    node.node_id,
                    node.get_content(metadata_mode=MetadataMode.NONE),
                )
                for node in batch
            ]
            batch_ids = [record[2] for record in records]

            self._collection.add(
                embeddings=[record[0] for record in records],
                ids=batch_ids,
                metadatas=[record[1] for record in records],
                documents=[record[3] for record in records],
            )
            inserted_ids += batch_ids

        return inserted_ids