from langchain_community.vectorstores import Cassandra
from loguru import logger

from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
from langflow.inputs import BoolInput, DictInput, FloatInput
from langflow.io import (
    DataInput,
    DropdownInput,
    HandleInput,
    IntInput,
    MessageTextInput,
    MultilineInput,
    SecretStrInput,
)
from langflow.schema import Data


class CassandraVectorStoreComponent(LCVectorStoreComponent):
    display_name = "Cassandra"
    description = "Cassandra Vector Store with search capabilities"
    documentation = "https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/cassandra"
    name = "Cassandra"
    icon = "Cassandra"

    inputs = [
        MessageTextInput(
            name="database_ref",
            display_name="Contact Points / Astra Database ID",
            info="Contact points for the database (or AstraDB database ID).",
            required=True,
        ),
        MessageTextInput(
            name="username",
            display_name="Username",
            info="Username for the database (leave empty for AstraDB).",
        ),
        SecretStrInput(
            name="token",
            display_name="Password / AstraDB Token",
            info="User password for the database (or AstraDB token).",
            required=True,
        ),
        MessageTextInput(
            name="keyspace",
            display_name="Keyspace",
            info="Table Keyspace (or AstraDB namespace).",
            required=True,
        ),
        MessageTextInput(
            name="table_name",
            display_name="Table Name",
            info="The name of the table (or AstraDB collection) where vectors will be stored.",
            required=True,
        ),
        IntInput(
            name="ttl_seconds",
            display_name="TTL Seconds",
            info="Optional time-to-live for the added texts.",
            advanced=True,
        ),
        IntInput(
            name="batch_size",
            display_name="Batch Size",
            info="Optional number of records to process in a single batch.",
            value=16,
            advanced=True,
        ),
        DropdownInput(
            name="setup_mode",
            display_name="Setup Mode",
            info="Configuration mode for setting up the Cassandra table, with options like 'Sync', 'Async', or 'Off'.",
            options=["Sync", "Async", "Off"],
            value="Sync",
            advanced=True,
        ),
        DictInput(
            name="cluster_kwargs",
            display_name="Cluster arguments",
            info="Optional dictionary of additional keyword arguments for the Cassandra cluster.",
            advanced=True,
            is_list=True,
        ),
        MultilineInput(name="search_query", display_name="Search Query"),
        DataInput(
            name="ingest_data",
            display_name="Ingest Data",
            is_list=True,
        ),
        HandleInput(name="embedding", display_name="Embedding", input_types=["Embeddings"]),
        IntInput(
            name="number_of_results",
            display_name="Number of Results",
            info="Number of results to return.",
            value=4,
            advanced=True,
        ),
        DropdownInput(
            name="search_type",
            display_name="Search Type",
            info="Search type to use.",
            options=["Similarity", "Similarity with score threshold", "MMR (Max Marginal Relevance)"],
            value="Similarity",
            advanced=True,
        ),
        FloatInput(
            name="search_score_threshold",
            display_name="Search Score Threshold",
            info="Minimum similarity score threshold for search results. "
            "(when using 'Similarity with score threshold')",
            value=0,
            advanced=True,
        ),
        DictInput(
            name="search_filter",
            display_name="Search Metadata Filter",
            info="Optional dictionary of filters to apply to the search query.",
            advanced=True,
            is_list=True,
        ),
        MessageTextInput(
            name="body_search",
            display_name="Search Body",
            info="Document textual search terms to apply to the search query.",
            advanced=True,
        ),
        BoolInput(
            name="enable_body_search",
            display_name="Enable Body Search",
            info="Flag to enable body search. This must be enabled BEFORE the table is created.",
            value=False,
            advanced=True,
        ),
    ]

    @check_cached_vector_store
    def build_vector_store(self) -> Cassandra:
        # cassio is an optional dependency; fail with a clear message if it is missing.
        try:
            import cassio
            from langchain_community.utilities.cassandra import SetupMode
        except ImportError as e:
            msg = "Could not import cassio integration package. Please install it with `pip install cassio`."
            raise ImportError(msg) from e

        from uuid import UUID

        database_ref = self.database_ref

        # An Astra DB database ID is a UUID; anything else is treated as contact points.
        try:
            UUID(self.database_ref)
            is_astra = True
        except ValueError:
            is_astra = False
            if "," in self.database_ref:
                # use a copy because we can't change the type of the parameter
                database_ref = self.database_ref.split(",")

        if is_astra:
            cassio.init(
                database_id=database_ref,
                token=self.token,
                cluster_kwargs=self.cluster_kwargs,
            )
        else:
            cassio.init(
                contact_points=database_ref,
                username=self.username,
                password=self.token,
                cluster_kwargs=self.cluster_kwargs,
            )

        # Convert Langflow Data objects to LangChain Documents; pass anything else through as-is.
        documents = []
        for _input in self.ingest_data or []:
            if isinstance(_input, Data):
                documents.append(_input.to_lc_document())
            else:
                documents.append(_input)

        # Only create the body index (STANDARD analyzer) when body search is enabled.
        body_index_options = [("index_analyzer", "STANDARD")] if self.enable_body_search else None

        if self.setup_mode == "Off":
            setup_mode = SetupMode.OFF
        elif self.setup_mode == "Sync":
            setup_mode = SetupMode.SYNC
        else:
            setup_mode = SetupMode.ASYNC

        if documents:
            logger.debug(f"Adding {len(documents)} documents to the Vector Store.")
            table = Cassandra.from_documents(
                documents=documents,
                embedding=self.embedding,
                table_name=self.table_name,
                keyspace=self.keyspace,
                ttl_seconds=self.ttl_seconds or None,
                batch_size=self.batch_size,
                body_index_options=body_index_options,
            )
        else:
            logger.debug("No documents to add to the Vector Store.")
            table = Cassandra(
                embedding=self.embedding,
                table_name=self.table_name,
                keyspace=self.keyspace,
                ttl_seconds=self.ttl_seconds or None,
                body_index_options=body_index_options,
                setup_mode=setup_mode,
            )
        return table

    def _map_search_type(self) -> str:
        if self.search_type == "Similarity with score threshold":
            return "similarity_score_threshold"
        if self.search_type == "MMR (Max Marginal Relevance)":
            return "mmr"
        return "similarity"

    def search_documents(self) -> list[Data]:
        vector_store = self.build_vector_store()

        logger.debug(f"Search input: {self.search_query}")
        logger.debug(f"Search type: {self.search_type}")
        logger.debug(f"Number of results: {self.number_of_results}")

        if self.search_query and isinstance(self.search_query, str) and self.search_query.strip():
            try:
                search_type = self._map_search_type()
                search_args = self._build_search_args()

                logger.debug(f"Search args: {search_args}")

                docs = vector_store.search(query=self.search_query, search_type=search_type, **search_args)
            except KeyError as e:
                if "content" in str(e):
                    msg = (
                        "You should ingest data through Langflow (or LangChain) to query it in Langflow. "
                        "Your collection does not contain a field named 'content'."
                    )
                    raise ValueError(msg) from e
                raise

            logger.debug(f"Retrieved documents: {len(docs)}")

            data = docs_to_data(docs)
            self.status = data
            return data
        return []

    def _build_search_args(self):
        args = {
            "k": self.number_of_results,
            "score_threshold": self.search_score_threshold,
        }

        if self.search_filter:
            # Drop filter entries with empty keys or values before passing them along.
            clean_filter = {k: v for k, v in self.search_filter.items() if k and v}
            if len(clean_filter) > 0:
                args["filter"] = clean_filter
        if self.body_search:
            if not self.enable_body_search:
                msg = "You should enable body search when creating the table to search the body field."
                raise ValueError(msg)
            args["body_search"] = self.body_search

        return args

    def get_retriever_kwargs(self):
        search_args = self._build_search_args()
        return {
            "search_type": self._map_search_type(),
            "search_kwargs": search_args,
        }
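

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the component): build_vector_store()
# above boils down to a cassio.init() call followed by constructing the
# LangChain Cassandra vector store. The contact point, credentials, keyspace,
# table name, and embeddings class below are illustrative assumptions, not
# values defined in this module; substitute your own before running.
#
#   import cassio
#   from langchain_community.vectorstores import Cassandra
#   from langchain_openai import OpenAIEmbeddings  # any Embeddings implementation works
#
#   cassio.init(contact_points="127.0.0.1", username="cassandra", password="cassandra")
#   store = Cassandra(
#       embedding=OpenAIEmbeddings(),
#       table_name="documents",
#       keyspace="langflow",
#   )
#   print(store.similarity_search("hello", k=4))
# ---------------------------------------------------------------------------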