import urllib from http import HTTPStatus from typing import Any import requests from langchain.pydantic_v1 import BaseModel, Field, create_model from langchain_core.tools import StructuredTool from langflow.base.langchain_utilities.model import LCToolComponent from langflow.io import DictInput, IntInput, SecretStrInput, StrInput from langflow.schema import Data class AstraDBCQLToolComponent(LCToolComponent): display_name: str = "Astra DB CQL" description: str = "Create a tool to get transactional data from DataStax Astra DB CQL Table" documentation: str = "https://docs.langflow.org/Components/components-tools#astra-db-cql-tool" icon: str = "AstraDB" inputs = [ StrInput(name="tool_name", display_name="Tool Name", info="The name of the tool.", required=True), StrInput( name="tool_description", display_name="Tool Description", info="The tool description to be passed to the model.", required=True, ), StrInput( name="keyspace", display_name="Keyspace", value="default_keyspace", info="The keyspace name within Astra DB where the data is stored.", required=True, advanced=True, ), StrInput( name="table_name", display_name="Table Name", info="The name of the table within Astra DB where the data is stored.", required=True, ), SecretStrInput( name="token", display_name="Astra DB Application Token", info="Authentication token for accessing Astra DB.", value="ASTRA_DB_APPLICATION_TOKEN", required=True, ), StrInput( name="api_endpoint", display_name="API Endpoint", info="API endpoint URL for the Astra DB service.", value="ASTRA_DB_API_ENDPOINT", required=True, ), StrInput( name="projection_fields", display_name="Projection fields", info="Attributes to return separated by comma.", required=True, value="*", advanced=True, ), DictInput( name="partition_keys", display_name="Partition Keys", is_list=True, info="Field name and description to the model", required=True, ), DictInput( name="clustering_keys", display_name="Clustering Keys", is_list=True, info="Field name and description to the model", ), DictInput( name="static_filters", display_name="Static Filters", is_list=True, advanced=True, info="Field name and value. When filled, it will not be generated by the LLM.", ), IntInput( name="number_of_results", display_name="Number of Results", info="Number of results to return.", advanced=True, value=5, ), ] def astra_rest(self, args): headers = {"Accept": "application/json", "X-Cassandra-Token": f"{self.token}"} astra_url = f"{self.api_endpoint}/api/rest/v2/keyspaces/{self.keyspace}/{self.table_name}/" key = [] # Partition keys are mandatory for k in self.partition_keys: if k in args: key.append(args[k]) elif self.static_filters[k] is not None: key.append(self.static_filters[k]) else: # TO-DO: Raise error - Missing information key.append("none") # Clustering keys are optional for k in self.clustering_keys: if k in args: key.append(args[k]) elif self.static_filters[k] is not None: key.append(self.static_filters[k]) url = f'{astra_url}{"/".join(key)}?page-size={self.number_of_results}' if self.projection_fields != "*": url += f'&fields={urllib.parse.quote(self.projection_fields.replace(" ", ""))}' res = requests.request("GET", url=url, headers=headers, timeout=10) if int(res.status_code) >= HTTPStatus.BAD_REQUEST: return res.text try: res_data = res.json() return res_data["data"] except ValueError: return res.status_code def create_args_schema(self) -> dict[str, BaseModel]: args: dict[str, tuple[Any, Field]] = {} for key in self.partition_keys: # Partition keys are mandatory is it doesn't have a static filter if key not in self.static_filters: args[key] = (str, Field(description=self.partition_keys[key])) for key in self.clustering_keys: # Partition keys are mandatory if has the exclamation mark and doesn't have a static filter if key not in self.static_filters: if key.startswith("!"): # Mandatory args[key[1:]] = (str, Field(description=self.clustering_keys[key])) else: # Optional args[key] = (str | None, Field(description=self.clustering_keys[key], default=None)) model = create_model("ToolInput", **args, __base__=BaseModel) return {"ToolInput": model} def build_tool(self) -> StructuredTool: """Builds a Astra DB CQL Table tool. Args: name (str, optional): The name of the tool. Returns: Tool: The built AstraDB tool. """ schema_dict = self.create_args_schema() return StructuredTool.from_function( name=self.tool_name, args_schema=schema_dict["ToolInput"], description=self.tool_description, func=self.run_model, return_direct=False, ) def projection_args(self, input_str: str) -> dict: elements = input_str.split(",") result = {} for element in elements: if element.startswith("!"): result[element[1:]] = False else: result[element] = True return result def run_model(self, **args) -> Data | list[Data]: results = self.astra_rest(args) data: list[Data] = [Data(data=doc) for doc in results] self.status = data return results