Spaces:
Running
Running
""" | |
Compatibility layer for the legacy VectorTable (used by the LangChain
integration as of August 2023).
Note: this is to be replaced by direct usage of the table-class-hierarchy classes.
""" | |
from typing import Any, Dict, List, Optional | |
from warnings import warn | |
from cassandra.cluster import ResponseFuture | |
from cassio.table.table_types import RowType | |
from cassio.table.tables import MetadataVectorCassandraTable | |
# Column names of the new table-class hierarchy -> key names exposed by the
# legacy VectorTable API.
new_columns_to_legacy = dict(
    row_id="document_id",
    body_blob="document",
    vector="embedding_vector",
)
# Inverse lookup: legacy key name -> new column name.
legacy_columns_to_new = dict(
    zip(new_columns_to_legacy.values(), new_columns_to_legacy.keys())
)
class VectorTable:
    """
    Legacy-compatible vector table, kept for the LangChain integration.

    This class is a rewriting of the VectorTable created for use in the
    LangChain integration, this time relying on the class-table-hierarchy
    (cassio.table.*). It mostly provides a translation layer between
    parameters and key names, using a metadata+vector table class
    (MetadataVectorCassandraTable) internally.
    Additional kwargs, for use in this new table class, are passed as they
    are, in order to enable their usage already before adapting the LangChain
    integration code.
    """

    DEPRECATION_MESSAGE = (
        "Class `VectorTable` is a legacy construct and "
        "will be deprecated in future versions of CassIO."
    )

    def __init__(self, *pargs: Any, **kwargs: Any):
        """
        Build the underlying MetadataVectorCassandraTable.

        Emits a DeprecationWarning attributed to the caller (stacklevel=2).
        Positional arguments are forwarded unchanged; keyword arguments are
        first translated by `_legacy_kwargs_to_new`.
        """
        warn(self.DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
        md_kwargs = self._legacy_kwargs_to_new(kwargs)
        self.table = MetadataVectorCassandraTable(*pargs, **md_kwargs)

    @staticmethod
    def _legacy_kwargs_to_new(kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Translate legacy constructor kwargs into new-table kwargs.

        - `embedding_dimension` is renamed to `vector_dimension` (it wins over
          a `vector_dimension` passed alongside it);
        - the legacy `auto_id` flag is always discarded, whether or not
          `embedding_dimension` is present, so it is never forwarded to the
          new table class;
        - `metadata_indexing` defaults to "all" (this legacy table has
          everything indexed for search) unless the caller supplied it.

        All other kwargs pass through unchanged; the input dict is not mutated.
        """
        new_kwargs = {
            k: v
            for k, v in kwargs.items()
            if k not in ("embedding_dimension", "auto_id")
        }
        if "embedding_dimension" in kwargs:
            new_kwargs["vector_dimension"] = kwargs["embedding_dimension"]
        # Caller-supplied keys come last so they override the default.
        return {"metadata_indexing": "all", **new_kwargs}

    def search(
        self,
        embedding_vector: List[float],
        top_k: int,
        metric: str = "cos",
        metric_threshold: Optional[float] = None,
        **kwargs: Any,
    ) -> List[RowType]:
        """
        Run an ANN search and return the hits with legacy key names.

        Args:
            embedding_vector: query vector.
            top_k: number of hits requested (passed as `n`).
            metric: metric name forwarded to `metric_ann_search`.
            metric_threshold: optional threshold forwarded as-is.
            **kwargs: forwarded to `metric_ann_search` unchanged.

        Returns:
            One dict per hit, with keys renamed via `new_columns_to_legacy`
            (keys not in that mapping are kept as they are).
        """
        enriched_hits = self.table.metric_ann_search(
            vector=embedding_vector,
            n=top_k,
            metric=metric,
            metric_threshold=metric_threshold,
            **kwargs,
        )
        return [self._make_dict_legacy(rich_hit) for rich_hit in enriched_hits]

    def put(
        self,
        document: str,
        embedding_vector: List[float],
        document_id: Any,
        metadata: Optional[Dict[str, Any]] = None,
        ttl_seconds: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """
        Store a row, translating legacy argument names to new column names.

        A `metadata` of None (or any falsy value) is stored as `{}`; extra
        kwargs are forwarded to the underlying table's `put`.
        """
        self.table.put(
            row_id=document_id,
            body_blob=document,
            vector=embedding_vector,
            metadata=metadata or {},
            ttl_seconds=ttl_seconds,
            **kwargs,
        )

    def put_async(
        self,
        document: str,
        embedding_vector: List[float],
        document_id: Any,
        metadata: Optional[Dict[str, Any]] = None,
        ttl_seconds: Optional[int] = None,
        **kwargs: Any,
    ) -> ResponseFuture:
        """
        Asynchronous counterpart of `put`.

        Returns whatever the underlying table's `put_async` returns
        (annotated as a driver `ResponseFuture`).
        """
        return self.table.put_async(
            row_id=document_id,
            body_blob=document,
            vector=embedding_vector,
            metadata=metadata or {},
            ttl_seconds=ttl_seconds,
            **kwargs,
        )

    def get(self, document_id: Any, **kwargs: Any) -> Optional[RowType]:
        """
        Fetch one row by id, returned with legacy key names.

        Returns None when the underlying table's `get` returns None.
        """
        row = self.table.get(row_id=document_id, **kwargs)
        if row is None:
            return None
        return self._make_dict_legacy(row)

    def delete(self, document_id: Any, **kwargs: Any) -> None:
        """Delete the row with the given id (extra kwargs forwarded)."""
        self.table.delete(row_id=document_id, **kwargs)

    def clear(self) -> None:
        """Remove all rows by delegating to the underlying table's `clear`."""
        self.table.clear()

    @staticmethod
    def _make_dict_legacy(new_dict: RowType) -> RowType:
        """
        Rename new-style column keys to their legacy names.

        Keys missing from `new_columns_to_legacy` are kept unchanged. This is
        a staticmethod because it is invoked as `self._make_dict_legacy(row)`
        and uses no instance state.
        """
        return {new_columns_to_legacy.get(k, k): v for k, v in new_dict.items()}