Tai Truong
fix readme
d202ada
import os
from typing import Any
from astrapy import Collection, DataAPIClient, Database
from langchain.pydantic_v1 import BaseModel, Field, create_model
from langchain_core.tools import StructuredTool
from langflow.base.langchain_utilities.model import LCToolComponent
from langflow.io import DictInput, IntInput, SecretStrInput, StrInput
from langflow.schema import Data
class AstraDBToolComponent(LCToolComponent):
    """Langflow component that exposes an Astra DB collection lookup as a LangChain tool.

    The generated tool accepts the parameters declared in ``tool_params``. Keys
    prefixed with ``!`` are mandatory; all others are optional. The tool merges
    the supplied values with ``static_filters`` and runs a ``find`` against the
    configured collection.
    """

    display_name: str = "Astra DB Tool"
    description: str = "Create a tool to get transactional data from DataStax Astra DB Collection"
    documentation: str = "https://docs.langflow.org/Components/components-tools#astra-db-tool"
    icon: str = "AstraDB"

    inputs = [
        StrInput(
            name="tool_name",
            display_name="Tool Name",
            info="The name of the tool.",
            required=True,
        ),
        StrInput(
            name="tool_description",
            display_name="Tool Description",
            info="The description of the tool.",
            required=True,
        ),
        StrInput(
            name="namespace",
            display_name="Namespace Name",
            info="The name of the namespace within Astra where the collection is stored.",
            value="default_keyspace",
            advanced=True,
        ),
        StrInput(
            name="collection_name",
            display_name="Collection Name",
            info="The name of the collection within Astra DB where the vectors will be stored.",
            required=True,
        ),
        SecretStrInput(
            name="token",
            display_name="Astra DB Application Token",
            info="Authentication token for accessing Astra DB.",
            value="ASTRA_DB_APPLICATION_TOKEN",
            required=True,
        ),
        SecretStrInput(
            name="api_endpoint",
            display_name="Database" if os.getenv("ASTRA_ENHANCED", "false").lower() == "true" else "API Endpoint",
            info="API endpoint URL for the Astra DB service.",
            value="ASTRA_DB_API_ENDPOINT",
            required=True,
        ),
        StrInput(
            name="projection_attributes",
            display_name="Projection Attributes",
            info="Attributes to return separated by comma.",
            required=True,
            value="*",
            advanced=True,
        ),
        DictInput(
            name="tool_params",
            info="Attributes to filter and description to the model. Add ! for mandatory (e.g: !customerId)",
            display_name="Tool params",
            is_list=True,
        ),
        DictInput(
            name="static_filters",
            info="Attributes to filter and corresponding value",
            display_name="Static filters",
            advanced=True,
            is_list=True,
        ),
        IntInput(
            name="number_of_results",
            display_name="Number of Results",
            info="Number of results to return.",
            advanced=True,
            value=5,
        ),
    ]

    # Lazily created connection handles; set on the instance by _build_collection.
    _cached_client: DataAPIClient | None = None
    _cached_db: Database | None = None
    _cached_collection: Collection | None = None

    def _build_collection(self) -> Collection:
        """Return the configured collection, connecting on first use and reusing it afterwards."""
        # Compare against None explicitly rather than relying on the truthiness
        # of the astrapy Collection object.
        if self._cached_collection is not None:
            return self._cached_collection

        self._cached_client = DataAPIClient(self.token)
        self._cached_db = self._cached_client.get_database(self.api_endpoint, namespace=self.namespace)
        self._cached_collection = self._cached_db.get_collection(self.collection_name)
        return self._cached_collection

    def create_args_schema(self) -> dict[str, type[BaseModel]]:
        """Build the pydantic argument model that the LLM fills in when calling the tool.

        Each ``tool_params`` key becomes a ``str`` field whose description is the
        mapped value. A leading ``!`` marks the field as required, and the ``!`` is
        stripped from the field name. All other fields are optional and default to None.

        Returns:
            A single-entry dict mapping ``"ToolInput"`` to the generated model class.
        """
        args: dict[str, tuple[Any, Any]] = {}
        # tool_params may be unset when the user declared no parameters.
        for key, field_description in (self.tool_params or {}).items():
            if key.startswith("!"):  # Mandatory
                args[key[1:]] = (str, Field(description=field_description))
            else:  # Optional
                args[key] = (str | None, Field(description=field_description, default=None))

        model = create_model("ToolInput", **args, __base__=BaseModel)
        return {"ToolInput": model}

    def build_tool(self) -> StructuredTool:
        """Builds an Astra DB Collection tool.

        The tool's arguments come from ``create_args_schema``. Invoking the tool
        calls ``run_model``.

        Returns:
            Tool: The built Astra DB tool.
        """
        schema_dict = self.create_args_schema()
        tool = StructuredTool.from_function(
            name=self.tool_name,
            args_schema=schema_dict["ToolInput"],
            description=self.tool_description,
            func=self.run_model,
            return_direct=False,
        )
        self.status = "Astra DB Tool created"
        return tool

    def projection_args(self, input_str: str) -> dict:
        """Translate a comma-separated attribute list into an Astra DB projection dict.

        ``"a, !b"`` becomes ``{"a": True, "b": False}``. A leading ``!`` excludes
        the attribute; anything else includes it. Whitespace around each entry is
        ignored. Empty entries, such as those produced by a trailing comma, are skipped.
        """
        result: dict[str, bool] = {}
        for raw_element in (input_str or "").split(","):
            element = raw_element.strip()
            include = not element.startswith("!")
            name = element if include else element[1:].strip()
            # Skip blanks so they never become a "" projection key.
            if name:
                result[name] = include
        return result

    def run_model(self, **args) -> Data | list[Data]:
        """Query the collection using the tool arguments merged with the static filters.

        Arguments left as None (unset optional parameters) are dropped so that
        they are not sent as ``{field: None}`` conditions in the filter. Static
        filters are applied last, so they override tool arguments with the same key.

        Returns:
            One ``Data`` per matching document, at most ``number_of_results`` of them.
        """
        collection = self._build_collection()
        query = {key: value for key, value in args.items() if value is not None}
        # static_filters may be unset; merging None would raise a TypeError.
        query.update(self.static_filters or {})
        results = collection.find(
            query,
            projection=self.projection_args(self.projection_attributes),
            limit=self.number_of_results,
        )
        data: list[Data] = [Data(data=doc) for doc in results]
        self.status = data
        return data