File size: 4,356 Bytes
2a0bc63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""
Compatibility layer for the legacy VectorTable (used by the LangChain
integration as of August 2023).

Note: This is to be replaced by direct usage of the table-class-hierarchy classes.
"""

from typing import Any, Dict, List, Optional
from warnings import warn

from cassandra.cluster import ResponseFuture

from cassio.table.table_types import RowType
from cassio.table.tables import MetadataVectorCassandraTable

# Column names of the new table classes -> key names of the legacy
# VectorTable API. The inverse mapping is derived from it so the two
# can never drift apart.
new_columns_to_legacy = dict(
    row_id="document_id",
    body_blob="document",
    vector="embedding_vector",
)
legacy_columns_to_new = dict(
    zip(new_columns_to_legacy.values(), new_columns_to_legacy.keys())
)


class VectorTable:
    """
    Legacy-compatible vector table built on the table-class hierarchy
    (cassio.table.*).

    This class re-implements the VectorTable originally written for the
    LangChain integration. It is mostly a translation layer between legacy
    parameter/key names (document_id, document, embedding_vector,
    embedding_dimension) and a MetadataVectorCassandraTable stored in
    ``self.table``.

    Additional kwargs are forwarded unchanged to the underlying table class,
    so that its options can already be used before the LangChain integration
    code is adapted.

    Instantiating this class emits a DeprecationWarning.
    """

    DEPRECATION_MESSAGE = (
        "Class `VectorTable` is a legacy construct and "
        "will be deprecated in future versions of CassIO."
    )

    def __init__(self, *pargs: Any, **kwargs: Any):
        """
        Create the underlying MetadataVectorCassandraTable.

        Positional and keyword arguments are forwarded to
        MetadataVectorCassandraTable, with these adjustments:
          - the legacy 'embedding_dimension' is renamed to 'vector_dimension'
            (its value wins if 'vector_dimension' is also given);
          - the legacy 'auto_id' is always dropped;
          - 'metadata_indexing' defaults to "all" unless the caller sets it.
        """
        warn(self.DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
        # Get rid of the infamous legacy 'auto_id' in every case, regardless of
        # which other kwargs were supplied, so it is never forwarded.
        table_kwargs = {k: v for k, v in kwargs.items() if k != "auto_id"}
        if "embedding_dimension" in table_kwargs:
            table_kwargs["vector_dimension"] = table_kwargs.pop("embedding_dimension")
        # This legacy VectorTable has everything indexed for search by default;
        # an explicit 'metadata_indexing' from the caller overrides it.
        md_kwargs = {"metadata_indexing": "all", **table_kwargs}
        self.table = MetadataVectorCassandraTable(*pargs, **md_kwargs)

    def search(
        self,
        embedding_vector: List[float],
        top_k: int,
        metric: str = "cos",
        metric_threshold: Optional[float] = None,
        **kwargs: Any,
    ) -> List[RowType]:
        """
        Run an ANN search and return the hits with legacy key names.

        Args:
            embedding_vector: the query vector.
            top_k: number of hits requested (forwarded as ``n``).
            metric: metric name, forwarded to ``metric_ann_search``.
            metric_threshold: optional cutoff, forwarded to ``metric_ann_search``.
            **kwargs: forwarded to ``metric_ann_search``.

        Returns:
            One dict per hit returned by ``metric_ann_search``, with the keys
            'row_id'/'body_blob'/'vector' renamed to
            'document_id'/'document'/'embedding_vector'; other keys are kept.
        """
        enriched_hits = self.table.metric_ann_search(
            vector=embedding_vector,
            n=top_k,
            metric=metric,
            metric_threshold=metric_threshold,
            **kwargs,
        )
        return [self._make_dict_legacy(rich_hit) for rich_hit in enriched_hits]

    @staticmethod
    def _legacy_put_kwargs(
        document: str,
        embedding_vector: List[float],
        document_id: Any,
        metadata: Optional[Dict[str, Any]],
        ttl_seconds: Optional[int],
    ) -> Dict[str, Any]:
        """Translate legacy ``put`` arguments into the new table's kwargs."""
        return {
            "row_id": document_id,
            "body_blob": document,
            "vector": embedding_vector,
            # a missing/None metadata is stored as an empty dict
            "metadata": metadata or {},
            "ttl_seconds": ttl_seconds,
        }

    def put(
        self,
        document: str,
        embedding_vector: List[float],
        document_id: Any,
        metadata: Optional[Dict[str, Any]] = None,
        ttl_seconds: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """
        Store a row synchronously via ``self.table.put``.

        Legacy names are mapped to the new columns (document_id -> row_id,
        document -> body_blob, embedding_vector -> vector). Extra kwargs are
        forwarded; passing one that collides with a mapped name raises
        TypeError (duplicate keyword argument).
        """
        self.table.put(
            **self._legacy_put_kwargs(
                document, embedding_vector, document_id, metadata, ttl_seconds
            ),
            **kwargs,
        )

    def put_async(
        self,
        document: str,
        embedding_vector: List[float],
        document_id: Any,
        metadata: Optional[Dict[str, Any]] = None,
        ttl_seconds: Optional[int] = None,
        **kwargs: Any,
    ) -> ResponseFuture:
        """
        Same as ``put`` but via ``self.table.put_async``.

        Returns:
            Whatever ``self.table.put_async`` returns (annotated as a
            ResponseFuture).
        """
        return self.table.put_async(
            **self._legacy_put_kwargs(
                document, embedding_vector, document_id, metadata, ttl_seconds
            ),
            **kwargs,
        )

    def get(self, document_id: Any, **kwargs: Any) -> Optional[RowType]:
        """
        Fetch one row by id.

        Returns:
            The row with legacy key names, or None if ``self.table.get``
            returned None.
        """
        row = self.table.get(row_id=document_id, **kwargs)
        if row is None:
            return None
        return self._make_dict_legacy(row)

    def delete(self, document_id: Any, **kwargs: Any) -> None:
        """Delete the row with the given id (extra kwargs are forwarded)."""
        self.table.delete(row_id=document_id, **kwargs)

    def clear(self) -> None:
        """Remove all rows by delegating to ``self.table.clear``."""
        self.table.clear()

    @staticmethod
    def _make_dict_legacy(new_dict: RowType) -> RowType:
        """Return a copy of *new_dict* with new column names mapped to legacy ones."""
        return {new_columns_to_legacy.get(k, k): v for k, v in new_dict.items()}