import os
from typing import Any

from astrapy import Collection, DataAPIClient, Database
from langchain.pydantic_v1 import BaseModel, Field, create_model
from langchain_core.tools import StructuredTool

from langflow.base.langchain_utilities.model import LCToolComponent
from langflow.io import DictInput, IntInput, SecretStrInput, StrInput
from langflow.schema import Data


class AstraDBToolComponent(LCToolComponent):
    """Langflow component that exposes an Astra DB collection lookup as a LangChain tool.

    The tool's argument schema is generated from the ``tool_params`` input, and
    every invocation runs a ``find`` on the configured collection using the
    supplied arguments merged with ``static_filters`` as the filter.
    """

    display_name: str = "Astra DB Tool"
    description: str = "Create a tool to get transactional data from DataStax Astra DB Collection"
    documentation: str = "https://docs.langflow.org/Components/components-tools#astra-db-tool"
    icon: str = "AstraDB"

    inputs = [
        StrInput(
            name="tool_name",
            display_name="Tool Name",
            info="The name of the tool.",
            required=True,
        ),
        StrInput(
            name="tool_description",
            display_name="Tool Description",
            info="The description of the tool.",
            required=True,
        ),
        StrInput(
            name="namespace",
            display_name="Namespace Name",
            info="The name of the namespace within Astra where the collection is be stored.",
            value="default_keyspace",
            advanced=True,
        ),
        StrInput(
            name="collection_name",
            display_name="Collection Name",
            info="The name of the collection within Astra DB where the vectors will be stored.",
            required=True,
        ),
        SecretStrInput(
            name="token",
            display_name="Astra DB Application Token",
            info="Authentication token for accessing Astra DB.",
            value="ASTRA_DB_APPLICATION_TOKEN",
            required=True,
        ),
        SecretStrInput(
            name="api_endpoint",
            # The label shown in the UI depends on the ASTRA_ENHANCED environment flag.
            display_name="Database" if os.getenv("ASTRA_ENHANCED", "false").lower() == "true" else "API Endpoint",
            info="API endpoint URL for the Astra DB service.",
            value="ASTRA_DB_API_ENDPOINT",
            required=True,
        ),
        StrInput(
            name="projection_attributes",
            display_name="Projection Attributes",
            info="Attributes to return separated by comma.",
            required=True,
            value="*",
            advanced=True,
        ),
        DictInput(
            name="tool_params",
            info="Attributes to filter and description to the model. Add ! for mandatory (e.g: !customerId)",
            display_name="Tool params",
            is_list=True,
        ),
        DictInput(
            name="static_filters",
            info="Attributes to filter and correspoding value",
            display_name="Static filters",
            advanced=True,
            is_list=True,
        ),
        IntInput(
            name="number_of_results",
            display_name="Number of Results",
            info="Number of results to return.",
            advanced=True,
            value=5,
        ),
    ]

    # Lazily created Astra DB handles; populated on the first _build_collection() call.
    _cached_client: DataAPIClient | None = None
    _cached_db: Database | None = None
    _cached_collection: Collection | None = None

    def _build_collection(self):
        """Return the Astra DB collection handle, creating and caching it on first use.

        Returns:
            The collection named by ``collection_name`` in the configured
            namespace of the database at ``api_endpoint``.
        """
        if self._cached_collection is not None:
            return self._cached_collection

        # Keep the intermediate handles on the instance as well, so the declared
        # cache attributes actually reflect what was built.
        # NOTE(review): the cache is never invalidated when token/endpoint/collection
        # inputs change on the same instance - confirm that is acceptable.
        self._cached_client = DataAPIClient(self.token)
        self._cached_db = self._cached_client.get_database(self.api_endpoint, namespace=self.namespace)
        self._cached_collection = self._cached_db.get_collection(self.collection_name)
        return self._cached_collection

    def create_args_schema(self) -> dict[str, type[BaseModel]]:
        """Build the pydantic model describing the tool's arguments.

        Each key of ``tool_params`` becomes a string field whose description is
        the corresponding value. A key prefixed with ``!`` is mandatory (the
        prefix is removed from the field name); any other key is optional and
        defaults to ``None``.

        Returns:
            A single-entry dict mapping ``"ToolInput"`` to the generated model class.
        """
        fields: dict[str, Any] = {}
        # Tolerate an unset tool_params input: the tool then takes no arguments.
        for key, description in (self.tool_params or {}).items():
            if key.startswith("!"):  # Mandatory
                fields[key[1:]] = (str, Field(description=description))
            else:  # Optional
                fields[key] = (str | None, Field(description=description, default=None))

        model = create_model("ToolInput", **fields, __base__=BaseModel)
        return {"ToolInput": model}

    def build_tool(self) -> StructuredTool:
        """Build the Astra DB collection tool.

        Returns:
            StructuredTool: A tool named ``tool_name`` whose arguments follow the
            schema from ``create_args_schema`` and whose calls are handled by
            ``run_model``.
        """
        schema_dict = self.create_args_schema()

        tool = StructuredTool.from_function(
            name=self.tool_name,
            args_schema=schema_dict["ToolInput"],
            description=self.tool_description,
            func=self.run_model,
            return_direct=False,
        )
        self.status = "Astra DB Tool created"

        return tool

    def projection_args(self, input_str: str) -> dict:
        """Turn a comma-separated attribute list into a projection mapping.

        Surrounding whitespace is ignored and empty items (for example from a
        trailing comma) are skipped. An attribute prefixed with ``!`` maps to
        ``False``; every other attribute maps to ``True``.

        Args:
            input_str: Comma-separated attribute names, e.g. ``"name, !secret"``.

        Returns:
            Mapping of attribute name to ``True`` or ``False``.
        """
        projection: dict[str, bool] = {}
        for raw in input_str.split(","):
            element = raw.strip()
            if not element:
                continue
            if element.startswith("!"):
                name = element[1:].strip()
                if name:
                    projection[name] = False
            else:
                projection[element] = True
        return projection

    def run_model(self, **args) -> Data | list[Data]:
        """Query the collection with the tool arguments and the static filters.

        Arguments whose value is ``None`` (optional parameters that were not
        supplied) are left out of the filter, so they do not restrict the query.
        Static filters take precedence over tool arguments with the same key.

        Args:
            **args: Filter values supplied by the tool caller, keyed by attribute name.

        Returns:
            One ``Data`` object per document returned by the query.
        """
        collection = self._build_collection()

        supplied = {key: value for key, value in args.items() if value is not None}
        # static_filters may be unset; treat that as "no static filters".
        query_filter = {**supplied, **(self.static_filters or {})}

        results = collection.find(
            query_filter,
            projection=self.projection_args(self.projection_attributes),
            limit=self.number_of_results,
        )

        data: list[Data] = [Data(data=doc) for doc in results]
        self.status = data
        return data