import os
from collections import defaultdict

import orjson
from astrapy import DataAPIClient
from astrapy.admin import parse_api_endpoint
from langchain_astradb import AstraDBVectorStore

from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers import docs_to_data
from langflow.inputs import DictInput, FloatInput, MessageTextInput, NestedDictInput
from langflow.io import (
    BoolInput,
    DataInput,
    DropdownInput,
    HandleInput,
    IntInput,
    MultilineInput,
    SecretStrInput,
    StrInput,
)
from langflow.schema import Data
from langflow.utils.version import get_version_info


class AstraDBVectorStoreComponent(LCVectorStoreComponent):
    display_name: str = "Astra DB"
    description: str = "Implementation of Vector Store using Astra DB with search capabilities"
    documentation: str = "https://docs.langflow.org/starter-projects-vector-store-rag"
    name = "AstraDB"
    icon: str = "AstraDB"

    _cached_vector_store: AstraDBVectorStore | None = None

    VECTORIZE_PROVIDERS_MAPPING = defaultdict(
        list,
        {
            "Azure OpenAI": [
                "azureOpenAI",
                ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"],
            ],
            "Hugging Face - Dedicated": ["huggingfaceDedicated", ["endpoint-defined-model"]],
            "Hugging Face - Serverless": [
                "huggingface",
                [
                    "sentence-transformers/all-MiniLM-L6-v2",
                    "intfloat/multilingual-e5-large",
                    "intfloat/multilingual-e5-large-instruct",
                    "BAAI/bge-small-en-v1.5",
                    "BAAI/bge-base-en-v1.5",
                    "BAAI/bge-large-en-v1.5",
                ],
            ],
            "Jina AI": [
                "jinaAI",
                [
                    "jina-embeddings-v2-base-en",
                    "jina-embeddings-v2-base-de",
                    "jina-embeddings-v2-base-es",
                    "jina-embeddings-v2-base-code",
                    "jina-embeddings-v2-base-zh",
                ],
            ],
            "Mistral AI": ["mistral", ["mistral-embed"]],
            "Nvidia": ["nvidia", ["NV-Embed-QA"]],
            "OpenAI": ["openai", ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]],
            "Upstage": ["upstageAI", ["solar-embedding-1-large"]],
            "Voyage AI": [
                "voyageAI",
                ["voyage-large-2-instruct", "voyage-law-2", "voyage-code-2", "voyage-large-2", "voyage-2"],
            ],
        },
    )

    inputs = [
        SecretStrInput(
            name="token",
            display_name="Astra DB Application Token",
            info="Authentication token for accessing Astra DB.",
            value="ASTRA_DB_APPLICATION_TOKEN",
            required=True,
            advanced=os.getenv("ASTRA_ENHANCED", "false").lower() == "true",
            real_time_refresh=True,
        ),
        SecretStrInput(
            name="api_endpoint",
            display_name="Database" if os.getenv("ASTRA_ENHANCED", "false").lower() == "true" else "API Endpoint",
            info="API endpoint URL for the Astra DB service.",
            value="ASTRA_DB_API_ENDPOINT",
            required=True,
            real_time_refresh=True,
        ),
        DropdownInput(
            name="collection_name",
            display_name="Collection",
            info="The name of the collection within Astra DB where the vectors will be stored.",
            required=True,
            refresh_button=True,
            real_time_refresh=True,
            options=["+ Create new collection"],
            value="+ Create new collection",
        ),
        StrInput(
            name="collection_name_new",
            display_name="Collection Name",
            info="Name of the new collection to create.",
            advanced=os.getenv("LANGFLOW_HOST") is not None,
            required=os.getenv("LANGFLOW_HOST") is None,
        ),
        StrInput(
            name="keyspace",
            display_name="Keyspace",
            info="Optional keyspace within Astra DB to use for the collection.",
            advanced=True,
        ),
        MultilineInput(
            name="search_input",
            display_name="Search Input",
        ),
        IntInput(
            name="number_of_results",
            display_name="Number of Results",
            info="Number of results to return.",
            advanced=True,
            value=4,
        ),
        DropdownInput(
            name="search_type",
            display_name="Search Type",
            info="Search type to use",
            options=["Similarity", "Similarity with score threshold", "MMR (Max Marginal Relevance)"],
            value="Similarity",
            advanced=True,
        ),
        FloatInput(
            name="search_score_threshold",
            display_name="Search Score Threshold",
            info="Minimum similarity score threshold for search results. "
            "(when using 'Similarity with score threshold')",
            value=0,
            advanced=True,
        ),
        NestedDictInput(
            name="advanced_search_filter",
            display_name="Search Metadata Filter",
            info="Optional dictionary of filters to apply to the search query.",
            advanced=True,
        ),
        DictInput(
            name="search_filter",
            display_name="[DEPRECATED] Search Metadata Filter",
            info="Deprecated: use advanced_search_filter. "
            "Optional dictionary of filters to apply to the search query.",
            advanced=True,
            list=True,
        ),
        DataInput(
            name="ingest_data",
            display_name="Ingest Data",
        ),
        DropdownInput(
            name="embedding_choice",
            display_name="Embedding Model or Astra Vectorize",
            info="Determines whether to use Astra Vectorize for the collection.",
            options=["Embedding Model", "Astra Vectorize"],
            real_time_refresh=True,
            value="Embedding Model",
        ),
        HandleInput(
            name="embedding_model",
            display_name="Embedding Model",
            input_types=["Embeddings"],
            info="Allows an embedding model configuration.",
        ),
        DropdownInput(
            name="metric",
            display_name="Metric",
            info="Optional distance metric for vector comparisons in the vector store.",
            options=["cosine", "dot_product", "euclidean"],
            value="cosine",
            advanced=True,
        ),
        IntInput(
            name="batch_size",
            display_name="Batch Size",
            info="Optional number of data to process in a single batch.",
            advanced=True,
        ),
        IntInput(
            name="bulk_insert_batch_concurrency",
            display_name="Bulk Insert Batch Concurrency",
            info="Optional concurrency level for bulk insert operations.",
            advanced=True,
        ),
        IntInput(
            name="bulk_insert_overwrite_concurrency",
            display_name="Bulk Insert Overwrite Concurrency",
            info="Optional concurrency level for bulk insert operations that overwrite existing data.",
            advanced=True,
        ),
        IntInput(
            name="bulk_delete_concurrency",
            display_name="Bulk Delete Concurrency",
            info="Optional concurrency level for bulk delete operations.",
            advanced=True,
        ),
        DropdownInput(
            name="setup_mode",
            display_name="Setup Mode",
            info="Configuration mode for setting up the vector store, with options like 'Sync' or 'Off'.",
            options=["Sync", "Off"],
            advanced=True,
            value="Sync",
        ),
        BoolInput(
            name="pre_delete_collection",
            display_name="Pre Delete Collection",
            info="Boolean flag to determine whether to delete the collection before creating a new one.",
            advanced=True,
        ),
        StrInput(
            name="metadata_indexing_include",
            display_name="Metadata Indexing Include",
            info="Optional list of metadata fields to include in the indexing.",
            list=True,
            advanced=True,
        ),
        StrInput(
            name="metadata_indexing_exclude",
            display_name="Metadata Indexing Exclude",
            info="Optional list of metadata fields to exclude from the indexing.",
            list=True,
            advanced=True,
        ),
        StrInput(
            name="collection_indexing_policy",
            display_name="Collection Indexing Policy",
            info='Optional JSON string for the "indexing" field of the collection. '
            "See https://docs.datastax.com/en/astra-db-serverless/api-reference/collections.html#the-indexing-option",
            advanced=True,
        ),
    ]

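    # Note: the Vectorize-related fields referenced by the helpers below (embedding_provider,
    # model, z_01_model_parameters, z_02_api_key_name, z_03_provider_api_key,
    # z_04_authentication) are not declared in `inputs`; they are inserted into and removed
    # from the build config dynamically by update_build_config().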
    def del_fields(self, build_config, field_list):
        for field in field_list:
            if field in build_config:
                del build_config[field]

        return build_config

    def insert_in_dict(self, build_config, field_name, new_parameters):
        # Insert the new key-value pair after the found key
        for new_field_name, new_parameter in new_parameters.items():
            # Get all the items as a list of tuples (key, value)
            items = list(build_config.items())

            # Find the index of the key to insert after
            idx = len(items)
            for i, (key, _) in enumerate(items):
                if key == field_name:
                    idx = i + 1
                    break

            items.insert(idx, (new_field_name, new_parameter))

            # Clear the original dictionary and update with the modified items
            build_config.clear()
            build_config.update(items)

        return build_config

    def update_providers_mapping(self):
        # If we don't have token or api_endpoint, we can't fetch the list of providers
        if not self.token or not self.api_endpoint:
            self.log("Astra DB token and API endpoint are required to fetch the list of Vectorize providers.")

            return self.VECTORIZE_PROVIDERS_MAPPING

        try:
            self.log("Dynamically updating list of Vectorize providers.")

            # Get the admin object
            client = DataAPIClient(token=self.token)
            admin = client.get_admin()

            # Get the embedding providers
            db_admin = admin.get_database_admin(self.api_endpoint)
            embedding_providers = db_admin.find_embedding_providers().as_dict()

            vectorize_providers_mapping = {}
            # Map the provider display name to the provider key and models
            for provider_key, provider_data in embedding_providers["embeddingProviders"].items():
                display_name = provider_data["displayName"]
                models = [model["name"] for model in provider_data["models"]]

                vectorize_providers_mapping[display_name] = [provider_key, models]

            # Sort the resulting dictionary
            return defaultdict(list, dict(sorted(vectorize_providers_mapping.items())))
        except Exception as e:  # noqa: BLE001
            self.log(f"Error fetching Vectorize providers: {e}")

            return self.VECTORIZE_PROVIDERS_MAPPING

    def get_database(self):
        try:
            client = DataAPIClient(token=self.token)

            return client.get_database(
                self.api_endpoint,
                token=self.token,
            )
        except Exception as e:  # noqa: BLE001
            self.log(f"Error getting database: {e}")

            return None

    def _initialize_collection_options(self):
        database = self.get_database()
        if database is None:
            return ["+ Create new collection"]

        try:
            collections = [collection.name for collection in database.list_collections()]
        except Exception as e:  # noqa: BLE001
            self.log(f"Error fetching collections: {e}")

            return ["+ Create new collection"]

        return [*collections, "+ Create new collection"]

    def get_collection_choice(self):
        collection_name = self.collection_name
        if collection_name == "+ Create new collection":
            return self.collection_name_new

        return collection_name

    def get_collection_options(self):
        # Only get the options if the collection exists
        database = self.get_database()
        if database is None:
            return None

        collection_name = self.get_collection_choice()
        try:
            collection = database.get_collection(collection_name)
            collection_options = collection.options()
        except Exception as _:  # noqa: BLE001
            return None

        return collection_options.vector

    def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
        # Refresh the collection name options
        build_config["collection_name"]["options"] = self._initialize_collection_options()

        # If the collection name is set to "+ Create new collection", show embedding choice
        if field_name == "collection_name" and field_value == "+ Create new collection":
"collection_name" and field_value == "+ Create new collection": build_config["embedding_choice"]["advanced"] = False build_config["embedding_choice"]["value"] = "Embedding Model" build_config["embedding_model"]["advanced"] = False build_config["collection_name_new"]["advanced"] = False build_config["collection_name_new"]["required"] = True # But if it's not, hide embedding choice elif field_name == "collection_name" and field_value != "+ Create new collection": build_config["embedding_choice"]["advanced"] = True build_config["collection_name_new"]["advanced"] = True build_config["collection_name_new"]["required"] = False build_config["collection_name_new"]["value"] = "" # Get the collection options for the selected collection collection_options = self.get_collection_options() # If the collection options are available (DB exists), show the advanced options if collection_options: build_config["embedding_choice"]["advanced"] = True if collection_options.service: self.del_fields( build_config, [ "embedding_provider", "model", "z_01_model_parameters", "z_02_api_key_name", "z_03_provider_api_key", "z_04_authentication", ], ) build_config["embedding_model"]["advanced"] = True build_config["embedding_choice"]["value"] = "Astra Vectorize" else: build_config["embedding_model"]["advanced"] = False build_config["embedding_provider"]["advanced"] = False build_config["embedding_choice"]["value"] = "Embedding Model" elif field_name == "embedding_choice": if field_value == "Astra Vectorize": build_config["embedding_model"]["advanced"] = True # Update the providers mapping vectorize_providers = self.update_providers_mapping() new_parameter = DropdownInput( name="embedding_provider", display_name="Embedding Provider", options=vectorize_providers.keys(), value="", required=True, real_time_refresh=True, ).to_dict() self.insert_in_dict(build_config, "embedding_choice", {"embedding_provider": new_parameter}) else: build_config["embedding_model"]["advanced"] = False self.del_fields( build_config, [ "embedding_provider", "model", "z_01_model_parameters", "z_02_api_key_name", "z_03_provider_api_key", "z_04_authentication", ], ) elif field_name == "embedding_provider": self.del_fields( build_config, ["model", "z_01_model_parameters", "z_02_api_key_name", "z_03_provider_api_key", "z_04_authentication"], ) # Update the providers mapping vectorize_providers = self.update_providers_mapping() model_options = vectorize_providers[field_value][1] new_parameter = DropdownInput( name="model", display_name="Model", info="The embedding model to use for the selected provider. Each provider has a different set of " "models available (full list at " "https://docs.datastax.com/en/astra-db-serverless/databases/embedding-generation.html):\n\n" f"{', '.join(model_options)}", options=model_options, value=None, required=True, real_time_refresh=True, ).to_dict() self.insert_in_dict(build_config, "embedding_provider", {"model": new_parameter}) elif field_name == "model": self.del_fields( build_config, ["z_01_model_parameters", "z_02_api_key_name", "z_03_provider_api_key", "z_04_authentication"], ) new_parameter_1 = DictInput( name="z_01_model_parameters", display_name="Model Parameters", list=True, ).to_dict() new_parameter_2 = MessageTextInput( name="z_02_api_key_name", display_name="API Key Name", info="The name of the embeddings provider API key stored on Astra. 
" "If set, it will override the 'ProviderKey' in the authentication parameters.", ).to_dict() new_parameter_3 = SecretStrInput( load_from_db=False, name="z_03_provider_api_key", display_name="Provider API Key", info="An alternative to the Astra Authentication that passes an API key for the provider " "with each request to Astra DB. " "This may be used when Vectorize is configured for the collection, " "but no corresponding provider secret is stored within Astra's key management system.", ).to_dict() new_parameter_4 = DictInput( name="z_04_authentication", display_name="Authentication Parameters", list=True, ).to_dict() self.insert_in_dict( build_config, "model", { "z_01_model_parameters": new_parameter_1, "z_02_api_key_name": new_parameter_2, "z_03_provider_api_key": new_parameter_3, "z_04_authentication": new_parameter_4, }, ) return build_config def build_vectorize_options(self, **kwargs): for attribute in [ "embedding_provider", "model", "z_01_model_parameters", "z_02_api_key_name", "z_03_provider_api_key", "z_04_authentication", ]: if not hasattr(self, attribute): setattr(self, attribute, None) # Fetch values from kwargs if any self.* attributes are None provider_mapping = self.update_providers_mapping() provider_value = provider_mapping.get(self.embedding_provider, [None])[0] or kwargs.get("embedding_provider") model_name = self.model or kwargs.get("model") authentication = {**(self.z_04_authentication or {}), **kwargs.get("z_04_authentication", {})} parameters = self.z_01_model_parameters or kwargs.get("z_01_model_parameters", {}) # Set the API key name if provided api_key_name = self.z_02_api_key_name or kwargs.get("z_02_api_key_name") provider_key = self.z_03_provider_api_key or kwargs.get("z_03_provider_api_key") if api_key_name: authentication["providerKey"] = api_key_name if authentication: provider_key = None authentication["providerKey"] = authentication["providerKey"].split(".")[0] # Set authentication and parameters to None if no values are provided if not authentication: authentication = None if not parameters: parameters = None return { # must match astrapy.info.CollectionVectorServiceOptions "collection_vector_service_options": { "provider": provider_value, "modelName": model_name, "authentication": authentication, "parameters": parameters, }, "collection_embedding_api_key": provider_key, } @check_cached_vector_store def build_vector_store(self, vectorize_options=None): try: from langchain_astradb import AstraDBVectorStore from langchain_astradb.utils.astradb import SetupMode except ImportError as e: msg = ( "Could not import langchain Astra DB integration package. " "Please install it with `pip install langchain-astradb`." 
        try:
            if not self.setup_mode:
                self.setup_mode = self._inputs["setup_mode"].options[0]

            setup_mode_value = SetupMode[self.setup_mode.upper()]
        except KeyError as e:
            msg = f"Invalid setup mode: {self.setup_mode}"
            raise ValueError(msg) from e

        metric_value = self.metric or None
        autodetect = False
        if self.embedding_choice == "Embedding Model":
            embedding_dict = {"embedding": self.embedding_model}
        # Use autodetect if the collection name is NOT set to "+ Create new collection"
        elif self.collection_name != "+ Create new collection":
            autodetect = True
            metric_value = None
            setup_mode_value = None
            embedding_dict = {}
        else:
            from astrapy.info import CollectionVectorServiceOptions

            # Grab the collection options if available
            collection_options = self.get_collection_options()

            # Ensure collection_options and its nested attributes are handled safely
            authentication = getattr(self, "z_04_authentication", {}) or (
                collection_options.service.authentication
                if collection_options and collection_options.service and collection_options.service.authentication
                else {}
            )

            # Build the vectorize options dictionary
            dict_options = vectorize_options or self.build_vectorize_options(
                embedding_provider=(
                    getattr(self, "embedding_provider", None)
                    or (
                        collection_options.service.provider
                        if collection_options and collection_options.service
                        else None
                    )
                ),
                model=(
                    getattr(self, "model", None)
                    or (
                        collection_options.service.model_name
                        if collection_options and collection_options.service
                        else None
                    )
                ),
                z_01_model_parameters=(
                    getattr(self, "z_01_model_parameters", None)
                    or (
                        collection_options.service.parameters
                        if collection_options and collection_options.service
                        else None
                    )
                ),
                z_02_api_key_name=(
                    getattr(self, "z_02_api_key_name", None)
                    or (authentication.get("apiKey") if authentication else None)
                ),
                z_03_provider_api_key=(
                    getattr(self, "z_03_provider_api_key", None)
                    or (authentication.get("providerKey") if authentication else None)
                ),
                z_04_authentication=authentication,
            )

            # Set the embedding dictionary
            embedding_dict = {
                "collection_vector_service_options": CollectionVectorServiceOptions.from_dict(
                    dict_options.get("collection_vector_service_options")
                ),
                "collection_embedding_api_key": dict_options.get("collection_embedding_api_key"),
            }

        # Get Langflow version and platform information
        __version__ = get_version_info()["version"]
        langflow_prefix = ""
        if os.getenv("LANGFLOW_HOST") is not None:
            langflow_prefix = "ds-"

        try:
            vector_store = AstraDBVectorStore(
                token=self.token,
                api_endpoint=self.api_endpoint,
                namespace=self.keyspace or None,
                collection_name=self.get_collection_choice(),
                autodetect_collection=autodetect,
                environment=(
                    parse_api_endpoint(getattr(self, "api_endpoint", None)).environment
                    if getattr(self, "api_endpoint", None)
                    else None
                ),
                metric=metric_value,
                batch_size=self.batch_size or None,
                bulk_insert_batch_concurrency=self.bulk_insert_batch_concurrency or None,
                bulk_insert_overwrite_concurrency=self.bulk_insert_overwrite_concurrency or None,
                bulk_delete_concurrency=self.bulk_delete_concurrency or None,
                setup_mode=setup_mode_value,
                pre_delete_collection=self.pre_delete_collection,
                metadata_indexing_include=[s for s in self.metadata_indexing_include if s] or None,
                metadata_indexing_exclude=[s for s in self.metadata_indexing_exclude if s] or None,
                collection_indexing_policy=orjson.dumps(self.collection_indexing_policy)
                if self.collection_indexing_policy
                else None,
                ext_callers=[(f"{langflow_prefix}langflow", __version__)],
                **embedding_dict,
            )
        except Exception as e:
            msg = f"Error initializing AstraDBVectorStore: {e}"
            raise ValueError(msg) from e

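        # Ingest any documents supplied via `ingest_data` before handing back the store.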
        self._add_documents_to_vector_store(vector_store)

        return vector_store

    def _add_documents_to_vector_store(self, vector_store) -> None:
        documents = []
        for _input in self.ingest_data or []:
            if isinstance(_input, Data):
                documents.append(_input.to_lc_document())
            else:
                msg = "Vector Store Inputs must be Data objects."
                raise TypeError(msg)

        if documents:
            self.log(f"Adding {len(documents)} documents to the Vector Store.")
            try:
                vector_store.add_documents(documents)
            except Exception as e:
                msg = f"Error adding documents to AstraDBVectorStore: {e}"
                raise ValueError(msg) from e
        else:
            self.log("No documents to add to the Vector Store.")

    def _map_search_type(self) -> str:
        if self.search_type == "Similarity with score threshold":
            return "similarity_score_threshold"
        if self.search_type == "MMR (Max Marginal Relevance)":
            return "mmr"
        return "similarity"

    def _build_search_args(self):
        query = self.search_input if isinstance(self.search_input, str) and self.search_input.strip() else None
        search_filter = (
            {k: v for k, v in self.search_filter.items() if k and v and k.strip()} if self.search_filter else None
        )

        if query:
            args = {
                "query": query,
                "search_type": self._map_search_type(),
                "k": self.number_of_results,
                "score_threshold": self.search_score_threshold,
            }
        elif self.advanced_search_filter or search_filter:
            args = {
                "n": self.number_of_results,
            }
        else:
            return {}

        filter_arg = self.advanced_search_filter or {}
        if search_filter:
            self.log(f"`search_filter` is deprecated. Use `advanced_search_filter`. Cleaned: {search_filter}")
            filter_arg.update(search_filter)

        if filter_arg:
            args["filter"] = filter_arg

        return args

    def search_documents(self, vector_store=None) -> list[Data]:
        vector_store = vector_store or self.build_vector_store()

        self.log(f"Search input: {self.search_input}")
        self.log(f"Search type: {self.search_type}")
        self.log(f"Number of results: {self.number_of_results}")

        try:
            search_args = self._build_search_args()
        except Exception as e:
            msg = f"Error in AstraDBVectorStore._build_search_args: {e}"
            raise ValueError(msg) from e

        if not search_args:
            self.log("No search input or filters provided. Skipping search.")
            return []

        docs = []
        search_method = "search" if "query" in search_args else "metadata_search"

        try:
            self.log(f"Calling vector_store.{search_method} with args: {search_args}")
            docs = getattr(vector_store, search_method)(**search_args)
        except Exception as e:
            msg = f"Error performing {search_method} in AstraDBVectorStore: {e}"
            raise ValueError(msg) from e

        self.log(f"Retrieved documents: {len(docs)}")

        data = docs_to_data(docs)
        self.log(f"Converted documents to data: {len(data)}")
        self.status = data
        return data

    def get_retriever_kwargs(self):
        search_args = self._build_search_args()

        return {
            "search_type": self._map_search_type(),
            "search_kwargs": search_args,
        }