File size: 2,773 Bytes
bf6d237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from typing import Any

from llama_index.schema import BaseNode, MetadataMode
from llama_index.vector_stores import ChromaVectorStore
from llama_index.vector_stores.chroma import chunk_list
from llama_index.vector_stores.utils import node_to_metadata_dict


class BatchedChromaVectorStore(ChromaVectorStore):
    """Chroma vector store, batching additions to avoid reaching the max batch limit.

    In this vector store, embeddings are stored within a ChromaDB collection.

    During query time, the index uses ChromaDB to query for the top
    k most similar nodes.

    Unlike a single bulk insert, ``add`` splits the incoming nodes into chunks
    of at most ``chroma_client.max_batch_size`` nodes and issues one
    ``collection.add`` call per chunk.

    Args:
        chroma_client (from chromadb.api.API):
            API instance; its ``max_batch_size`` is used as the chunk size
            when inserting nodes.
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance
        host, port, ssl, headers:
            Forwarded unchanged to ``ChromaVectorStore.__init__``.
        collection_kwargs:
            Forwarded to ``ChromaVectorStore.__init__``; ``None`` is replaced
            by an empty dict.

    """

    # Declared at class level so the assignment in ``__init__`` is accepted.
    # NOTE(review): presumably required because the parent is a pydantic-style
    # model that rejects undeclared attributes — confirm against llama_index.
    chroma_client: Any | None

    def __init__(
        self,
        chroma_client: Any,
        chroma_collection: Any,
        host: str | None = None,
        port: str | None = None,
        ssl: bool = False,
        headers: dict[str, str] | None = None,
        collection_kwargs: dict[Any, Any] | None = None,
    ) -> None:
        """Initialize the parent store, then remember the client for batching."""
        super().__init__(
            chroma_collection=chroma_collection,
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            collection_kwargs=collection_kwargs or {},
        )
        # Set after the parent initializer, which does not receive the client.
        self.chroma_client = chroma_client

    def add(self, nodes: list[BaseNode], **add_kwargs: Any) -> list[str]:
        """Add nodes to the collection in batches of at most ``max_batch_size``.

        Args:
            nodes: Nodes to insert, each expected to carry an embedding.
            add_kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The IDs of all inserted nodes, in the order their chunks were added.

        Raises:
            ValueError: If the client or the collection has not been set.
        """
        # Explicit ``is None`` checks: a truthiness test would also reject a
        # set client/collection object that merely evaluates as false.
        if self.chroma_client is None:
            raise ValueError("Client not initialized")

        if self._collection is None:
            raise ValueError("Collection not initialized")

        max_chunk_size = self.chroma_client.max_batch_size

        all_ids: list[str] = []
        for node_chunk in chunk_list(nodes, max_chunk_size):
            embeddings = []
            metadatas = []
            ids = []
            documents = []
            for node in node_chunk:
                embeddings.append(node.get_embedding())
                # remove_text=True: the node text is sent separately through
                # ``documents`` below.
                metadatas.append(
                    node_to_metadata_dict(
                        node, remove_text=True, flat_metadata=self.flat_metadata
                    )
                )
                ids.append(node.node_id)
                documents.append(node.get_content(metadata_mode=MetadataMode.NONE))

            # One request per chunk keeps each call within the client's limit.
            # No rollback is attempted: if a later call raises, chunks added by
            # earlier iterations remain in the collection.
            self._collection.add(
                embeddings=embeddings,
                ids=ids,
                metadatas=metadatas,
                documents=documents,
            )
            all_ids.extend(ids)

        return all_ids