Spaces:
Running
Running
File size: 4,356 Bytes
2a0bc63 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
"""
Compatibility layer for legacy VectorTable (used by LangChain integration
(as of August 2023).
Note: This is to be replaced by direct usage of the table-class-hierarchy classes.
"""
from typing import Any, Dict, List, Optional
from warnings import warn
from cassandra.cluster import ResponseFuture
from cassio.table.table_types import RowType
from cassio.table.tables import MetadataVectorCassandraTable
new_columns_to_legacy = {
"row_id": "document_id",
"body_blob": "document",
"vector": "embedding_vector",
}
legacy_columns_to_new = {v: k for k, v in new_columns_to_legacy.items()}
class VectorTable:
"""
This class is a rewriting of the VectorTable created for use in LangChain
integration, this time relying on the class-table-hierarchy (cassio.table.*).
It mostly provides a translation layer between parameters and key names,
using a metadata+vector table class internally.
Additional kwargs, for use in this new table class, are passed as they are
in order to enable their usage already before adapting the LangChain
integration code.
"""
DEPRECATION_MESSAGE = (
"Class `VectorTable` is a legacy construct and "
"will be deprecated in future versions of CassIO."
)
def __init__(self, *pargs: Any, **kwargs: Any):
#
warn(self.DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
#
if "embedding_dimension" in kwargs:
vector_dimension = kwargs["embedding_dimension"]
new_kwargs = {
**{
k: v
for k, v in kwargs.items()
if k != "embedding_dimension"
# let's get rid of the infamous 'auto_id' here:
if k != "auto_id"
},
**{"vector_dimension": vector_dimension},
}
else:
new_kwargs = kwargs
# this legacy VectorTable will have everything indexed for search:
md_kwargs = {
**{"metadata_indexing": "all"},
**new_kwargs,
}
#
self.table = MetadataVectorCassandraTable(*pargs, **md_kwargs)
def search(
self,
embedding_vector: List[float],
top_k: int,
metric: str = "cos",
metric_threshold: Optional[float] = None,
**kwargs: Any,
) -> List[RowType]:
# get rows by ANN
enriched_hits = self.table.metric_ann_search(
vector=embedding_vector,
n=top_k,
metric=metric,
metric_threshold=metric_threshold,
**kwargs,
)
#
return [self._make_dict_legacy(rich_hit) for rich_hit in enriched_hits]
def put(
self,
document: str,
embedding_vector: List[float],
document_id: Any,
metadata: Optional[Dict[str, Any]] = None,
ttl_seconds: Optional[int] = None,
**kwargs: Any,
) -> None:
self.table.put(
row_id=document_id,
body_blob=document,
vector=embedding_vector,
metadata=metadata or {},
ttl_seconds=ttl_seconds,
**kwargs,
)
def put_async(
self,
document: str,
embedding_vector: List[float],
document_id: Any,
metadata: Optional[Dict[str, Any]] = None,
ttl_seconds: Optional[int] = None,
**kwargs: Any,
) -> ResponseFuture:
return self.table.put_async(
row_id=document_id,
body_blob=document,
vector=embedding_vector,
metadata=metadata or {},
ttl_seconds=ttl_seconds,
**kwargs,
)
def get(self, document_id: Any, **kwargs: Any) -> Optional[RowType]:
row_or_none = self.table.get(row_id=document_id, **kwargs)
if row_or_none:
return self._make_dict_legacy(row_or_none)
else:
return row_or_none
def delete(self, document_id: Any, **kwargs: Any) -> None:
self.table.delete(row_id=document_id, **kwargs)
return None
def clear(self) -> None:
self.table.clear()
return None
@staticmethod
def _make_dict_legacy(new_dict: RowType) -> RowType:
return {new_columns_to_legacy.get(k, k): v for k, v in new_dict.items()}
|