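# NOTE: the "Spaces: / Running / Running" lines that stood here were web-page residue from extraction, not part of the module.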
from uuid import UUID

from langchain_community.vectorstores import Cassandra
from loguru import logger

from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
from langflow.inputs import BoolInput, DictInput, FloatInput
from langflow.io import (
    DataInput,
    DropdownInput,
    HandleInput,
    IntInput,
    MessageTextInput,
    MultilineInput,
    SecretStrInput,
)
from langflow.schema import Data


class CassandraVectorStoreComponent(LCVectorStoreComponent):
    """Langflow component exposing the LangChain ``Cassandra`` vector store.

    The component connects through ``cassio.init``: when ``database_ref``
    parses as a UUID it is passed as ``database_id`` together with ``token``;
    otherwise it is passed as ``contact_points`` together with ``username``
    and ``token`` (as the password). It can ingest ``ingest_data`` into the
    table and run similarity / score-threshold / MMR searches against it.
    """

    display_name = "Cassandra"
    description = "Cassandra Vector Store with search capabilities"
    documentation = "https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/cassandra"
    name = "Cassandra"
    icon = "Cassandra"

    inputs = [
        # --- Connection settings ---
        MessageTextInput(
            name="database_ref",
            display_name="Contact Points / Astra Database ID",
            info="Contact points for the database (or AstraDB database ID)",
            required=True,
        ),
        MessageTextInput(
            name="username", display_name="Username", info="Username for the database (leave empty for AstraDB)."
        ),
        SecretStrInput(
            name="token",
            display_name="Password / AstraDB Token",
            info="User password for the database (or AstraDB token).",
            required=True,
        ),
        MessageTextInput(
            name="keyspace",
            display_name="Keyspace",
            info="Table Keyspace (or AstraDB namespace).",
            required=True,
        ),
        MessageTextInput(
            name="table_name",
            display_name="Table Name",
            info="The name of the table (or AstraDB collection) where vectors will be stored.",
            required=True,
        ),
        # --- Table / ingestion settings ---
        IntInput(
            name="ttl_seconds",
            display_name="TTL Seconds",
            info="Optional time-to-live for the added texts.",
            advanced=True,
        ),
        IntInput(
            name="batch_size",
            display_name="Batch Size",
            info="Optional number of data to process in a single batch.",
            value=16,
            advanced=True,
        ),
        DropdownInput(
            name="setup_mode",
            display_name="Setup Mode",
            info="Configuration mode for setting up the Cassandra table, with options like 'Sync', 'Async', or 'Off'.",
            options=["Sync", "Async", "Off"],
            value="Sync",
            advanced=True,
        ),
        DictInput(
            name="cluster_kwargs",
            display_name="Cluster arguments",
            info="Optional dictionary of additional keyword arguments for the Cassandra cluster.",
            advanced=True,
            is_list=True,
        ),
        # --- Search / ingest inputs ---
        MultilineInput(name="search_query", display_name="Search Query"),
        DataInput(
            name="ingest_data",
            display_name="Ingest Data",
            is_list=True,
        ),
        HandleInput(name="embedding", display_name="Embedding", input_types=["Embeddings"]),
        IntInput(
            name="number_of_results",
            display_name="Number of Results",
            info="Number of results to return.",
            value=4,
            advanced=True,
        ),
        DropdownInput(
            name="search_type",
            display_name="Search Type",
            info="Search type to use",
            options=["Similarity", "Similarity with score threshold", "MMR (Max Marginal Relevance)"],
            value="Similarity",
            advanced=True,
        ),
        FloatInput(
            name="search_score_threshold",
            display_name="Search Score Threshold",
            info="Minimum similarity score threshold for search results. "
            "(when using 'Similarity with score threshold')",
            value=0,
            advanced=True,
        ),
        DictInput(
            name="search_filter",
            display_name="Search Metadata Filter",
            info="Optional dictionary of filters to apply to the search query.",
            advanced=True,
            is_list=True,
        ),
        MessageTextInput(
            name="body_search",
            display_name="Search Body",
            info="Document textual search terms to apply to the search query.",
            advanced=True,
        ),
        BoolInput(
            name="enable_body_search",
            display_name="Enable Body Search",
            info="Flag to enable body search. This must be enabled BEFORE the table is created.",
            value=False,
            advanced=True,
        ),
    ]

    def _resolve_database_ref(self) -> tuple:
        """Work out the connection target from ``database_ref``.

        Returns:
            A ``(database_ref, is_astra)`` pair. ``is_astra`` is True when the
            reference parses as a UUID, in which case the stripped string is
            returned. Otherwise the reference is treated as contact points:
            a comma-separated value becomes a list of trimmed, non-empty
            hosts, and a value without commas is returned as a stripped string.

        Raises:
            ValueError: If a comma-separated reference contains no usable
                contact point (for example ``","``).
        """
        # Surrounding whitespace would make a valid UUID fail to parse and be
        # misread as a contact point, so trim before probing.
        ref = self.database_ref.strip()
        try:
            UUID(ref)
        except ValueError:
            pass
        else:
            return ref, True

        if "," not in ref:
            return ref, False

        # Build a new list rather than reassigning the input attribute, and
        # trim each host so "host1, host2" does not yield " host2".
        contact_points = [part.strip() for part in ref.split(",") if part.strip()]
        if not contact_points:
            msg = "No contact points found in the database reference."
            raise ValueError(msg)
        return contact_points, False

    # NOTE(review): this module imports `check_cached_vector_store` but it is
    # not applied here; `search_documents` calls this method on every search,
    # which re-runs the ingestion of `ingest_data` — confirm whether the
    # decorator was intended on this method.
    def build_vector_store(self) -> Cassandra:
        """Initialise the cassio connection and return the vector store.

        When ``ingest_data`` yields documents they are written through
        ``Cassandra.from_documents``; otherwise a plain ``Cassandra`` store is
        constructed with the selected setup mode.

        Returns:
            The ``Cassandra`` vector store bound to ``keyspace``/``table_name``.

        Raises:
            ImportError: If ``cassio`` (or the LangChain Cassandra utilities)
                cannot be imported.
            ValueError: If ``database_ref`` contains no usable contact point.
        """
        # cassio is an optional dependency, so it is imported lazily here.
        try:
            import cassio
            from langchain_community.utilities.cassandra import SetupMode
        except ImportError as e:
            msg = "Could not import cassio integration package. Please install it with `pip install cassio`."
            raise ImportError(msg) from e

        database_ref, is_astra = self._resolve_database_ref()
        if is_astra:
            cassio.init(
                database_id=database_ref,
                token=self.token,
                cluster_kwargs=self.cluster_kwargs,
            )
        else:
            cassio.init(
                contact_points=database_ref,
                username=self.username,
                password=self.token,
                cluster_kwargs=self.cluster_kwargs,
            )

        # Data objects are converted to LangChain documents; anything else is
        # passed through unchanged.
        documents = [
            item.to_lc_document() if isinstance(item, Data) else item
            for item in self.ingest_data or []
        ]

        body_index_options = [("index_analyzer", "STANDARD")] if self.enable_body_search else None

        # Any value other than "Off"/"Sync" falls back to ASYNC.
        setup_mode = {
            "Off": SetupMode.OFF,
            "Sync": SetupMode.SYNC,
        }.get(self.setup_mode, SetupMode.ASYNC)

        if documents:
            logger.debug(f"Adding {len(documents)} documents to the Vector Store.")
            # NOTE(review): `setup_mode` is not forwarded on this path, so the
            # selected mode only takes effect when there is nothing to ingest —
            # confirm whether `from_documents` honours it before passing it.
            table = Cassandra.from_documents(
                documents=documents,
                embedding=self.embedding,
                table_name=self.table_name,
                keyspace=self.keyspace,
                ttl_seconds=self.ttl_seconds or None,
                batch_size=self.batch_size,
                body_index_options=body_index_options,
            )
        else:
            logger.debug("No documents to add to the Vector Store.")
            table = Cassandra(
                embedding=self.embedding,
                table_name=self.table_name,
                keyspace=self.keyspace,
                ttl_seconds=self.ttl_seconds or None,
                body_index_options=body_index_options,
                setup_mode=setup_mode,
            )
        return table

    def _map_search_type(self) -> str:
        """Translate the UI search-type label into the LangChain identifier.

        Unknown labels fall back to ``"similarity"``.
        """
        if self.search_type == "Similarity with score threshold":
            return "similarity_score_threshold"
        if self.search_type == "MMR (Max Marginal Relevance)":
            return "mmr"
        return "similarity"

    def search_documents(self) -> list[Data]:
        """Run ``search_query`` against the vector store.

        Returns:
            The matching documents converted with ``docs_to_data``, or an
            empty list when ``search_query`` is missing, not a string, or
            blank. The result is also stored in ``self.status``.

        Raises:
            ValueError: If the search fails with a ``KeyError`` mentioning
                ``content``, or if body search is requested while
                ``enable_body_search`` is off.
        """
        vector_store = self.build_vector_store()

        logger.debug(f"Search input: {self.search_query}")
        logger.debug(f"Search type: {self.search_type}")
        logger.debug(f"Number of results: {self.number_of_results}")

        has_query = bool(self.search_query and isinstance(self.search_query, str) and self.search_query.strip())
        if not has_query:
            return []

        try:
            search_type = self._map_search_type()
            search_args = self._build_search_args()
            logger.debug(f"Search args: {search_args}")
            docs = vector_store.search(query=self.search_query, search_type=search_type, **search_args)
        except KeyError as e:
            # Only a KeyError mentioning 'content' is rewritten into the
            # friendlier message; any other KeyError propagates unchanged.
            if "content" in str(e):
                msg = (
                    "You should ingest data through Langflow (or LangChain) to query it in Langflow. "
                    "Your collection does not contain a field name 'content'."
                )
                raise ValueError(msg) from e
            raise

        logger.debug(f"Retrieved documents: {len(docs)}")
        data = docs_to_data(docs)
        self.status = data
        return data

    def _build_search_args(self) -> dict:
        """Assemble the keyword arguments passed to the vector store search.

        Returns:
            A dict that always holds ``k`` and ``score_threshold``, plus
            ``filter`` when ``search_filter`` has usable entries and
            ``body_search`` when body-search terms are set.

        Raises:
            ValueError: If ``body_search`` is set while ``enable_body_search``
                is off.
        """
        args = {
            "k": self.number_of_results,
            "score_threshold": self.search_score_threshold,
        }

        if self.search_filter:
            # Entries with a falsy key or a falsy value are discarded.
            clean_filter = {k: v for k, v in self.search_filter.items() if k and v}
            if clean_filter:
                args["filter"] = clean_filter

        if self.body_search:
            if not self.enable_body_search:
                msg = "You should enable body search when creating the table to search the body field."
                raise ValueError(msg)
            args["body_search"] = self.body_search

        return args

    def get_retriever_kwargs(self) -> dict:
        """Return the ``search_type`` / ``search_kwargs`` pair for a retriever."""
        search_args = self._build_search_args()
        return {
            "search_type": self._map_search_type(),
            "search_kwargs": search_args,
        }