Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	| from typing import Any | |
| from llama_index.schema import BaseNode, MetadataMode | |
| from llama_index.vector_stores import ChromaVectorStore | |
| from llama_index.vector_stores.chroma import chunk_list | |
| from llama_index.vector_stores.utils import node_to_metadata_dict | |
class BatchedChromaVectorStore(ChromaVectorStore):
    """ChromaDB-backed vector store that inserts nodes in size-limited batches.

    Embeddings live in a ChromaDB collection, and queries use ChromaDB to
    retrieve the top-k most similar nodes. Unlike the parent class, ``add``
    splits the incoming nodes into chunks no larger than the client's
    ``max_batch_size``. Each chunk is sent in its own insertion call, so the
    client's batch limit is never exceeded.

    Args:
        chroma_client (from chromadb.api.API):
            API instance whose ``max_batch_size`` bounds each insertion.
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance that receives the nodes.
    """

    # ChromaDB client handle; ``add`` reads its ``max_batch_size``.
    chroma_client: Any | None

    def __init__(
        self,
        chroma_client: Any,
        chroma_collection: Any,
        host: str | None = None,
        port: str | None = None,
        ssl: bool = False,
        headers: dict[str, str] | None = None,
        collection_kwargs: dict[Any, Any] | None = None,
    ) -> None:
        """Forward connection settings to ``ChromaVectorStore``, then keep the client.

        ``collection_kwargs`` defaults to an empty dict when not supplied.
        """
        effective_collection_kwargs = collection_kwargs or {}
        super().__init__(
            chroma_collection=chroma_collection,
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            collection_kwargs=effective_collection_kwargs,
        )
        # Assigned after the parent initializer has set up the model.
        self.chroma_client = chroma_client

    def add(self, nodes: list[BaseNode], **add_kwargs: Any) -> list[str]:
        """Insert ``nodes`` into the collection, one client-sized batch at a time.

        Args:
            nodes: nodes carrying embeddings to store.
            add_kwargs: accepted for interface compatibility; unused.

        Returns:
            The ids of every inserted node, in insertion order.

        Raises:
            ValueError: if the client or the collection is missing.
        """
        if not self.chroma_client:
            raise ValueError("Client not initialized")
        if not self._collection:
            raise ValueError("Collection not initialized")

        batch_limit = self.chroma_client.max_batch_size
        inserted_ids: list[str] = []
        for batch in chunk_list(nodes, batch_limit):
            # One tuple per node: (embedding, metadata, id, text). Fields are
            # read per node in this fixed order, so a failing node raises from
            # the same accessor regardless of how the columns are later split.
            rows = [
                (
                    node.get_embedding(),
                    node_to_metadata_dict(
                        node, remove_text=True, flat_metadata=self.flat_metadata
                    ),
                    node.node_id,
                    node.get_content(metadata_mode=MetadataMode.NONE),
                )
                for node in batch
            ]
            batch_ids = [row[2] for row in rows]
            self._collection.add(
                embeddings=[row[0] for row in rows],
                ids=batch_ids,
                metadatas=[row[1] for row in rows],
                documents=[row[3] for row in rows],
            )
            inserted_ids += batch_ids
        return inserted_ids