# Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import asyncio import deprecation import httpx import logging import json import threading from collections.abc import ( AsyncGenerator, AsyncIterator, Generator, Iterator, ) from concurrent.futures import ThreadPoolExecutor from functools import partial from queue import Queue from types import TracebackType from typing import ( Any, cast, Dict, List, Optional, Tuple, Union, Type, ) from astrapy import __version__ from astrapy.core.api import APIRequestError, api_request, async_api_request from astrapy.core.defaults import ( DEFAULT_AUTH_HEADER, DEFAULT_JSON_API_PATH, DEFAULT_JSON_API_VERSION, DEFAULT_KEYSPACE_NAME, MAX_INSERT_NUM_DOCUMENTS, ) from astrapy.core.utils import ( convert_vector_to_floats, make_payload, normalize_for_api, restore_from_api, http_methods, to_httpx_timeout, TimeoutInfoWideType, ) from astrapy.core.core_types import ( API_DOC, API_RESPONSE, PaginableRequestMethod, AsyncPaginableRequestMethod, ) logger = logging.getLogger(__name__) class AstraDBCollection: # Initialize the shared httpx client as a class attribute client = httpx.Client() def __init__( self, collection_name: str, astra_db: Optional[AstraDB] = None, token: Optional[str] = None, api_endpoint: Optional[str] = None, namespace: Optional[str] = None, caller_name: Optional[str] = None, caller_version: Optional[str] = None, ) -> None: """ Initialize an AstraDBCollection instance. Args: collection_name (str): The name of the collection. astra_db (AstraDB, optional): An instance of Astra DB. token (str, optional): Authentication token for Astra DB. api_endpoint (str, optional): API endpoint URL. namespace (str, optional): Namespace for the database. caller_name (str, optional): identity of the caller ("my_framework") If passing a client, its caller is used as fallback caller_version (str, optional): version of the caller code ("1.0.3") If passing a client, its caller is used as fallback """ # Check for presence of the Astra DB object if astra_db is None: if token is None or api_endpoint is None: raise AssertionError("Must provide token and api_endpoint") astra_db = AstraDB( token=token, api_endpoint=api_endpoint, namespace=namespace, caller_name=caller_name, caller_version=caller_version, ) else: # if astra_db passed, copy and apply possible overrides astra_db = astra_db.copy( token=token, api_endpoint=api_endpoint, namespace=namespace, caller_name=caller_name, caller_version=caller_version, ) # Set the remaining instance attributes self.astra_db = astra_db self.caller_name: Optional[str] = self.astra_db.caller_name self.caller_version: Optional[str] = self.astra_db.caller_version self.collection_name = collection_name self.base_path: str = f"{self.astra_db.base_path}/{self.collection_name}" def __repr__(self) -> str: return f'AstraDBCollection[astra_db="{self.astra_db}", collection_name="{self.collection_name}"]' def __eq__(self, other: Any) -> bool: if isinstance(other, AstraDBCollection): return all( [ self.collection_name == other.collection_name, self.astra_db == other.astra_db, self.caller_name == other.caller_name, self.caller_version == other.caller_version, ] ) else: return False def copy( self, *, collection_name: Optional[str] = None, token: Optional[str] = None, api_endpoint: Optional[str] = None, api_path: Optional[str] = None, api_version: Optional[str] = None, namespace: Optional[str] = None, caller_name: Optional[str] = None, caller_version: Optional[str] = None, ) -> AstraDBCollection: return AstraDBCollection( collection_name=collection_name or self.collection_name, astra_db=self.astra_db.copy( token=token, api_endpoint=api_endpoint, api_path=api_path, api_version=api_version, namespace=namespace, caller_name=caller_name, caller_version=caller_version, ), caller_name=caller_name or self.caller_name, caller_version=caller_version or self.caller_version, ) def to_async(self) -> AsyncAstraDBCollection: return AsyncAstraDBCollection( collection_name=self.collection_name, astra_db=self.astra_db.to_async(), caller_name=self.caller_name, caller_version=self.caller_version, ) def set_caller( self, caller_name: Optional[str] = None, caller_version: Optional[str] = None, ) -> None: self.astra_db.set_caller( caller_name=caller_name, caller_version=caller_version, ) self.caller_name = caller_name self.caller_version = caller_version def _request( self, method: str = http_methods.POST, path: Optional[str] = None, json_data: Optional[Dict[str, Any]] = None, url_params: Optional[Dict[str, Any]] = None, skip_error_check: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: direct_response = api_request( client=self.client, base_url=self.astra_db.base_url, auth_header=DEFAULT_AUTH_HEADER, token=self.astra_db.token, method=method, json_data=normalize_for_api(json_data), url_params=url_params, path=path, skip_error_check=skip_error_check, caller_name=self.caller_name, caller_version=self.caller_version, timeout=to_httpx_timeout(timeout_info), ) response = restore_from_api(direct_response) return response def post_raw_request( self, body: Dict[str, Any], timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: return self._request( method=http_methods.POST, path=self.base_path, json_data=body, timeout_info=timeout_info, ) def _get( self, path: Optional[str] = None, options: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> Optional[API_RESPONSE]: full_path = f"{self.base_path}/{path}" if path else self.base_path response = self._request( method=http_methods.GET, path=full_path, url_params=options, timeout_info=timeout_info, ) if isinstance(response, dict): return response return None def _put( self, path: Optional[str] = None, document: Optional[API_RESPONSE] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: full_path = f"{self.base_path}/{path}" if path else self.base_path response = self._request( method=http_methods.PUT, path=full_path, json_data=document, timeout_info=timeout_info, ) return response def _post( self, path: Optional[str] = None, document: Optional[API_DOC] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: full_path = f"{self.base_path}/{path}" if path else self.base_path response = self._request( method=http_methods.POST, path=full_path, json_data=document, timeout_info=timeout_info, ) return response def _recast_as_sort_projection( self, vector: List[float], fields: Optional[List[str]] = None ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]: """ Given a vector and optionally a list of fields, reformulate them as a sort, projection pair for regular 'find'-like API calls (with basic validation as well). """ # Must pass a vector if not vector: raise ValueError("Must pass a vector") # Edge case for field selection if fields and "$similarity" in fields: raise ValueError("Please use the `include_similarity` parameter") # Build the new vector parameter sort: Dict[str, Any] = {"$vector": vector} # Build the new fields parameter # Note: do not leave projection={}, make it None # (or it will devour $similarity away in the API response) if fields is not None and len(fields) > 0: projection = {f: 1 for f in fields} else: projection = None return sort, projection def get( self, path: Optional[str] = None, timeout_info: TimeoutInfoWideType = None ) -> Optional[API_RESPONSE]: """ Retrieve a document from the collection by its path. Args: path (str, optional): The path of the document to retrieve. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The retrieved document. """ return self._get(path=path, timeout_info=timeout_info) def find( self, filter: Optional[Dict[str, Any]] = None, projection: Optional[Dict[str, Any]] = None, sort: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Find documents in the collection that match the given filter. Args: filter (dict, optional): Criteria to filter documents. projection (dict, optional): Specifies the fields to return. sort (dict, optional): Specifies the order in which to return matching documents. options (dict, optional): Additional options for the query. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The query response containing matched documents. """ json_query = make_payload( top_level="find", filter=filter, projection=projection, options=options, sort=sort, ) response = self._post(document=json_query, timeout_info=timeout_info) return response def vector_find( self, vector: List[float], *, limit: int, filter: Optional[Dict[str, Any]] = None, fields: Optional[List[str]] = None, include_similarity: bool = True, timeout_info: TimeoutInfoWideType = None, ) -> List[API_DOC]: """ Perform a vector-based search in the collection. Args: vector (list): The vector to search with. limit (int): The maximum number of documents to return. filter (dict, optional): Criteria to filter documents. fields (list, optional): Specifies the fields to return. include_similarity (bool, optional): Whether to include similarity score in the result. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: list: A list of documents matching the vector search criteria. """ # Must pass a limit if not limit: raise ValueError("Must pass a limit") # Pre-process the included arguments sort, projection = self._recast_as_sort_projection( convert_vector_to_floats(vector), fields=fields, ) # Call the underlying find() method to search raw_find_result = self.find( filter=filter, projection=projection, sort=sort, options={ "limit": limit, "includeSimilarity": include_similarity, }, timeout_info=timeout_info, ) return cast(List[API_DOC], raw_find_result["data"]["documents"]) @staticmethod def paginate( *, request_method: PaginableRequestMethod, options: Optional[Dict[str, Any]], prefetched: Optional[int] = None, ) -> Generator[API_DOC, None, None]: """ Generate paginated results for a given database query method. Args: request_method (function): The database query method to paginate. options (dict, optional): Options for the database query. prefetched (int, optional): Number of pre-fetched documents. Yields: dict: The next document in the paginated result set. """ _options = options or {} response0 = request_method(options=_options) next_page_state = response0["data"]["nextPageState"] options0 = _options if next_page_state is not None and prefetched: def queued_paginate( queue: Queue[Optional[API_DOC]], request_method: PaginableRequestMethod, options: Optional[Dict[str, Any]], ) -> None: try: for row in AstraDBCollection.paginate( request_method=request_method, options=options ): queue.put(row) finally: queue.put(None) queue: Queue[Optional[API_DOC]] = Queue(prefetched) options1 = {**options0, **{"pageState": next_page_state}} t = threading.Thread( target=queued_paginate, args=(queue, request_method, options1) ) t.start() for document in response0["data"]["documents"]: yield document doc = queue.get() while doc is not None: yield doc doc = queue.get() t.join() else: for document in response0["data"]["documents"]: yield document while next_page_state is not None and not prefetched: options1 = {**options0, **{"pageState": next_page_state}} response1 = request_method(options=options1) for document in response1["data"]["documents"]: yield document next_page_state = response1["data"]["nextPageState"] def paginated_find( self, filter: Optional[Dict[str, Any]] = None, projection: Optional[Dict[str, Any]] = None, sort: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, prefetched: Optional[int] = None, timeout_info: TimeoutInfoWideType = None, ) -> Iterator[API_DOC]: """ Perform a paginated search in the collection. Args: filter (dict, optional): Criteria to filter documents. projection (dict, optional): Specifies the fields to return. sort (dict, optional): Specifies the order in which to return matching documents. options (dict, optional): Additional options for the query. prefetched (int, optional): Number of pre-fetched documents. timeout_info: a float, or a TimeoutInfo dict, for each single HTTP request. This is a paginated method, that issues several requests as it needs more data. This parameter controls a single request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: generator: A generator yielding documents in the paginated result set. """ partialed_find = partial( self.find, filter=filter, projection=projection, sort=sort, timeout_info=timeout_info, ) return self.paginate( request_method=partialed_find, options=options, prefetched=prefetched, ) def pop( self, filter: Dict[str, Any], pop: Dict[str, Any], options: Dict[str, Any], timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Pop the last data in the tags array Args: filter (dict): Criteria to identify the document to update. pop (dict): The pop to apply to the tags. options (dict): Additional options for the update operation. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The original document before the update. """ json_query = make_payload( top_level="findOneAndUpdate", filter=filter, update={"$pop": pop}, options=options, ) response = self._request( method=http_methods.POST, path=self.base_path, json_data=json_query, timeout_info=timeout_info, ) return response def push( self, filter: Dict[str, Any], push: Dict[str, Any], options: Dict[str, Any], timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Push new data to the tags array Args: filter (dict): Criteria to identify the document to update. push (dict): The push to apply to the tags. options (dict): Additional options for the update operation. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The result of the update operation. """ json_query = make_payload( top_level="findOneAndUpdate", filter=filter, update={"$push": push}, options=options, ) response = self._request( method=http_methods.POST, path=self.base_path, json_data=json_query, timeout_info=timeout_info, ) return response def find_one_and_replace( self, replacement: Dict[str, Any], *, filter: Optional[Dict[str, Any]] = None, projection: Optional[Dict[str, Any]] = None, sort: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Find a single document and replace it. Args: replacement (dict): The new document to replace the existing one. filter (dict, optional): Criteria to filter documents. sort (dict, optional): Specifies the order in which to find the document. options (dict, optional): Additional options for the operation. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The result of the find and replace operation. """ json_query = make_payload( top_level="findOneAndReplace", filter=filter, projection=projection, replacement=replacement, options=options, sort=sort, ) response = self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, timeout_info=timeout_info, ) return response def vector_find_one_and_replace( self, vector: List[float], replacement: Dict[str, Any], *, filter: Optional[Dict[str, Any]] = None, fields: Optional[List[str]] = None, timeout_info: TimeoutInfoWideType = None, ) -> Union[API_DOC, None]: """ Perform a vector-based search and replace the first matched document. Args: vector (dict): The vector to search with. replacement (dict): The new document to replace the existing one. filter (dict, optional): Criteria to filter documents. fields (list, optional): Specifies the fields to return in the result. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict or None: either the matched document or None if nothing found """ # Pre-process the included arguments sort, projection = self._recast_as_sort_projection( convert_vector_to_floats(vector), fields=fields, ) # Call the underlying find() method to search raw_find_result = self.find_one_and_replace( replacement=replacement, filter=filter, projection=projection, sort=sort, timeout_info=timeout_info, ) return cast(Union[API_DOC, None], raw_find_result["data"]["document"]) def find_one_and_update( self, update: Dict[str, Any], sort: Optional[Dict[str, Any]] = {}, filter: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, projection: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Find a single document and update it. Args: update (dict): The update to apply to the document. sort (dict, optional): Specifies the order in which to find the document. filter (dict, optional): Criteria to filter documents. options (dict, optional): Additional options for the operation. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The result of the find and update operation. """ json_query = make_payload( top_level="findOneAndUpdate", filter=filter, update=update, options=options, sort=sort, projection=projection, ) response = self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, timeout_info=timeout_info, ) return response def vector_find_one_and_update( self, vector: List[float], update: Dict[str, Any], *, filter: Optional[Dict[str, Any]] = None, fields: Optional[List[str]] = None, timeout_info: TimeoutInfoWideType = None, ) -> Union[API_DOC, None]: """ Perform a vector-based search and update the first matched document. Args: vector (list): The vector to search with. update (dict): The update to apply to the matched document. filter (dict, optional): Criteria to filter documents before applying the vector search. fields (list, optional): Specifies the fields to return in the updated document. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict or None: The result of the vector-based find and update operation, or None if nothing found """ # Pre-process the included arguments sort, projection = self._recast_as_sort_projection( convert_vector_to_floats(vector), fields=fields, ) # Call the underlying find() method to search raw_find_result = self.find_one_and_update( update=update, filter=filter, sort=sort, projection=projection, timeout_info=timeout_info, ) return cast(Union[API_DOC, None], raw_find_result["data"]["document"]) def find_one_and_delete( self, sort: Optional[Dict[str, Any]] = {}, filter: Optional[Dict[str, Any]] = None, projection: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Find a single document and delete it. Args: sort (dict, optional): Specifies the order in which to find the document. filter (dict, optional): Criteria to filter documents. projection (dict, optional): Specifies the fields to return. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The result of the find and delete operation. """ json_query = make_payload( top_level="findOneAndDelete", filter=filter, sort=sort, projection=projection, ) response = self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, timeout_info=timeout_info, ) return response def count_documents( self, filter: Dict[str, Any] = {}, timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: """ Count documents matching a given predicate (expressed as filter). Args: filter (dict, defaults to {}): Criteria to filter documents. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: the response, either {"status": {"count": }} or {"errors": [...]} """ json_query = make_payload( top_level="countDocuments", filter=filter, ) response = self._post(document=json_query, timeout_info=timeout_info) return response def find_one( self, filter: Optional[Dict[str, Any]] = {}, projection: Optional[Dict[str, Any]] = {}, sort: Optional[Dict[str, Any]] = {}, options: Optional[Dict[str, Any]] = {}, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Find a single document in the collection. Args: filter (dict, optional): Criteria to filter documents. projection (dict, optional): Specifies the fields to return. sort (dict, optional): Specifies the order in which to return the document. options (dict, optional): Additional options for the query. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: the response, either {"data": {"document": }} or {"data": {"document": None}} depending on whether a matching document is found or not. """ json_query = make_payload( top_level="findOne", filter=filter, projection=projection, options=options, sort=sort, ) response = self._post(document=json_query, timeout_info=timeout_info) return response def vector_find_one( self, vector: List[float], *, filter: Optional[Dict[str, Any]] = None, fields: Optional[List[str]] = None, include_similarity: bool = True, timeout_info: TimeoutInfoWideType = None, ) -> Union[API_DOC, None]: """ Perform a vector-based search to find a single document in the collection. Args: vector (list): The vector to search with. filter (dict, optional): Additional criteria to filter documents. fields (list, optional): Specifies the fields to return in the result. include_similarity (bool, optional): Whether to include similarity score in the result. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict or None: The found document or None if no matching document is found. """ # Pre-process the included arguments sort, projection = self._recast_as_sort_projection( convert_vector_to_floats(vector), fields=fields, ) # Call the underlying find() method to search raw_find_result = self.find_one( filter=filter, projection=projection, sort=sort, options={"includeSimilarity": include_similarity}, timeout_info=timeout_info, ) return cast(Union[API_DOC, None], raw_find_result["data"]["document"]) def insert_one( self, document: API_DOC, failures_allowed: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Insert a single document into the collection. Args: document (dict): The document to insert. failures_allowed (bool): Whether to allow failures in the insert operation. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database after the insert operation. """ json_query = make_payload(top_level="insertOne", document=document) response = self._request( method=http_methods.POST, path=self.base_path, json_data=json_query, skip_error_check=failures_allowed, timeout_info=timeout_info, ) return response def insert_many( self, documents: List[API_DOC], options: Optional[Dict[str, Any]] = None, partial_failures_allowed: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Insert multiple documents into the collection. Args: documents (list): A list of documents to insert. options (dict, optional): Additional options for the insert operation. partial_failures_allowed (bool, optional): Whether to allow partial failures through the insertion (i.e. on some documents). timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database after the insert operation. """ json_query = make_payload( top_level="insertMany", documents=documents, options=options ) # Send the data response = self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, skip_error_check=partial_failures_allowed, timeout_info=timeout_info, ) return response def chunked_insert_many( self, documents: List[API_DOC], options: Optional[Dict[str, Any]] = None, partial_failures_allowed: bool = False, chunk_size: int = MAX_INSERT_NUM_DOCUMENTS, concurrency: int = 1, timeout_info: TimeoutInfoWideType = None, ) -> List[Union[API_RESPONSE, Exception]]: """ Insert multiple documents into the collection, handling chunking and optionally with concurrent insertions. Args: documents (list): A list of documents to insert. options (dict, optional): Additional options for the insert operation. partial_failures_allowed (bool, optional): Whether to allow partial failures in the chunk. Should be used combined with options={"ordered": False} in most cases. chunk_size (int, optional): Override the default insertion chunk size. concurrency (int, optional): The number of concurrent chunk insertions. Default is no concurrency. timeout_info: a float, or a TimeoutInfo dict, for each single HTTP request. This method runs a number of HTTP requests as it works on chunked data. The timeout refers to each individual such request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: list: The responses from the database after the chunked insert operation. This is a list of individual responses from the API: the caller will need to inspect them all, e.g. to collate the inserted IDs. """ results: List[Union[API_RESPONSE, Exception]] = [] # Raise a warning if ordered and concurrency if options and options.get("ordered") is True and concurrency > 1: logger.warning( "Using ordered insert with concurrency may lead to unexpected results." ) # If we have concurrency as 1, don't use a thread pool if concurrency == 1: # Split the documents into chunks for i in range(0, len(documents), chunk_size): try: results.append( self.insert_many( documents[i : i + chunk_size], options, partial_failures_allowed, timeout_info=timeout_info, ) ) except APIRequestError as e: if partial_failures_allowed: results.append(e) else: raise e return results # Perform the bulk insert with concurrency otherwise with ThreadPoolExecutor(max_workers=concurrency) as executor: # Submit the jobs futures = [ executor.submit( self.insert_many, documents[i : i + chunk_size], options, partial_failures_allowed, timeout_info=timeout_info, ) for i in range(0, len(documents), chunk_size) ] # Collect the results for future in futures: try: results.append(future.result()) except APIRequestError as e: if partial_failures_allowed: results.append(e) else: raise e return results def update_one( self, filter: Dict[str, Any], update: Dict[str, Any], sort: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Update a single document in the collection. Args: filter (dict): Criteria to identify the document to update. update (dict): The update to apply to the document. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database after the update operation. """ json_query = make_payload( top_level="updateOne", filter=filter, update=update, sort=sort, ) response = self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, timeout_info=timeout_info, ) return response def update_many( self, filter: Dict[str, Any], update: Dict[str, Any], options: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Updates multiple documents in the collection. Args: filter (dict): Criteria to identify the document to update. update (dict): The update to apply to the document. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database after the update operation. """ json_query = make_payload( top_level="updateMany", filter=filter, update=update, options=options, ) response = self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, timeout_info=timeout_info, ) return response def replace( self, path: str, document: API_DOC, timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: """ Replace a document in the collection. Args: path (str): The path to the document to replace. document (dict): The new document to replace the existing one. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database after the replace operation. """ return self._put(path=path, document=document, timeout_info=timeout_info) @deprecation.deprecated( # type: ignore deprecated_in="0.7.0", removed_in="1.0.0", current_version=__version__, details="Use the 'delete_one' method instead", ) def delete(self, id: str, timeout_info: TimeoutInfoWideType = None) -> API_RESPONSE: return self.delete_one(id, timeout_info=timeout_info) def delete_one( self, id: str, sort: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Delete a single document from the collection based on its ID. Args: id (str): The ID of the document to delete. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database after the delete operation. """ json_query = make_payload( top_level="deleteOne", filter={"_id": id}, sort=sort, ) response = self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, timeout_info=timeout_info, ) return response def delete_one_by_predicate( self, filter: Dict[str, Any], sort: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Delete a single document from the collection based on a filter clause Args: filter: any filter dictionary timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database after the delete operation. """ json_query = make_payload( top_level="deleteOne", filter=filter, sort=sort, ) response = self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, timeout_info=timeout_info, ) return response def delete_many( self, filter: Dict[str, Any], skip_error_check: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Delete many documents from the collection based on a filter condition Args: filter (dict): Criteria to identify the documents to delete. skip_error_check (bool): whether to ignore the check for API error and return the response untouched. Default is False. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database after the delete operation. """ json_query = { "deleteMany": { "filter": filter, } } response = self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, skip_error_check=skip_error_check, timeout_info=timeout_info, ) return response def chunked_delete_many( self, filter: Dict[str, Any], timeout_info: TimeoutInfoWideType = None ) -> List[API_RESPONSE]: """ Delete many documents from the collection based on a filter condition, chaining several API calls until exhaustion of the documents to delete. Args: filter (dict): Criteria to identify the documents to delete. timeout_info: a float, or a TimeoutInfo dict, for each single HTTP request. This method runs a number of HTTP requests as it works on a pagination basis. The timeout refers to each individual such request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: List[dict]: The responses from the database from all the calls """ responses = [] must_proceed = True while must_proceed: dm_response = self.delete_many(filter=filter, timeout_info=timeout_info) responses.append(dm_response) must_proceed = dm_response.get("status", {}).get("moreData", False) return responses def clear(self, timeout_info: TimeoutInfoWideType = None) -> API_RESPONSE: """ Clear the collection, deleting all documents Args: timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database. """ clear_response = self.delete_many(filter={}, timeout_info=timeout_info) if clear_response.get("status", {}).get("deletedCount") != -1: raise ValueError( f"Could not issue a clear-collection API command (response: {json.dumps(clear_response)})." ) return clear_response def delete_subdocument( self, id: str, subdoc: str, timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: """ Delete a subdocument or field from a document in the collection. Args: id (str): The ID of the document containing the subdocument. subdoc (str): The key of the subdocument or field to remove. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database after the update operation. """ json_query = { "findOneAndUpdate": { "filter": {"_id": id}, "update": {"$unset": {subdoc: ""}}, } } response = self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, timeout_info=timeout_info, ) return response @deprecation.deprecated( # type: ignore deprecated_in="0.7.0", removed_in="1.0.0", current_version=__version__, details="Use the 'upsert_one' method instead", ) def upsert( self, document: API_DOC, timeout_info: TimeoutInfoWideType = None ) -> str: return self.upsert_one(document, timeout_info=timeout_info) def upsert_one( self, document: API_DOC, timeout_info: TimeoutInfoWideType = None ) -> str: """ Emulate an upsert operation for a single document in the collection. This method attempts to insert the document. If a document with the same _id exists, it updates the existing document. Args: document (dict): The document to insert or update. timeout_info: a float, or a TimeoutInfo dict, for the HTTP requests. This method may issue one or two requests, depending on what is detected on DB. This timeout controls each HTTP request individually. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: str: The _id of the inserted or updated document. """ # Build the payload for the insert attempt result = self.insert_one( document, failures_allowed=True, timeout_info=timeout_info ) # If the call failed because of preexisting doc, then we replace it if "errors" in result: if ( "errorCode" in result["errors"][0] and result["errors"][0]["errorCode"] == "DOCUMENT_ALREADY_EXISTS" ): # Now we attempt the update result = self.find_one_and_replace( replacement=document, filter={"_id": document["_id"]}, timeout_info=timeout_info, ) upserted_id = cast(str, result["data"]["document"]["_id"]) else: raise ValueError(result) else: if result.get("status", {}).get("insertedIds", []): upserted_id = cast(str, result["status"]["insertedIds"][0]) else: raise ValueError("Unexplained empty insertedIds from API") return upserted_id def upsert_many( self, documents: list[API_DOC], concurrency: int = 1, partial_failures_allowed: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> List[Union[str, Exception]]: """ Emulate an upsert operation for multiple documents in the collection. This method attempts to insert the documents. If a document with the same _id exists, it updates the existing document. Args: documents (List[dict]): The documents to insert or update. concurrency (int, optional): The number of concurrent upserts. partial_failures_allowed (bool, optional): Whether to allow partial failures in the batch. timeout_info: a float, or a TimeoutInfo dict, for each HTTP request. This method issues a separate HTTP request for each document to insert: the timeout controls each such request individually. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: List[Union[str, Exception]]: A list of "_id"s of the inserted or updated documents. """ results: List[Union[str, Exception]] = [] # If concurrency is 1, no need for thread pool if concurrency == 1: for document in documents: try: results.append(self.upsert_one(document, timeout_info=timeout_info)) except Exception as e: results.append(e) return results # Perform the bulk upsert with concurrency with ThreadPoolExecutor(max_workers=concurrency) as executor: # Submit the jobs futures = [ executor.submit(self.upsert, document, timeout_info=timeout_info) for document in documents ] # Collect the results for future in futures: try: results.append(future.result()) except Exception as e: if partial_failures_allowed: results.append(e) else: raise e return results class AsyncAstraDBCollection: def __init__( self, collection_name: str, astra_db: Optional[AsyncAstraDB] = None, token: Optional[str] = None, api_endpoint: Optional[str] = None, namespace: Optional[str] = None, caller_name: Optional[str] = None, caller_version: Optional[str] = None, ) -> None: """ Initialize an AstraDBCollection instance. Args: collection_name (str): The name of the collection. astra_db (AstraDB, optional): An instance of Astra DB. token (str, optional): Authentication token for Astra DB. api_endpoint (str, optional): API endpoint URL. namespace (str, optional): Namespace for the database. caller_name (str, optional): identity of the caller ("my_framework") If passing a client, its caller is used as fallback caller_version (str, optional): version of the caller code ("1.0.3") If passing a client, its caller is used as fallback """ # Check for presence of the Astra DB object if astra_db is None: if token is None or api_endpoint is None: raise AssertionError("Must provide token and api_endpoint") astra_db = AsyncAstraDB( token=token, api_endpoint=api_endpoint, namespace=namespace, caller_name=caller_name, caller_version=caller_version, ) else: # if astra_db passed, copy and apply possible overrides astra_db = astra_db.copy( token=token, api_endpoint=api_endpoint, namespace=namespace, caller_name=caller_name, caller_version=caller_version, ) # Set the remaining instance attributes self.astra_db: AsyncAstraDB = astra_db self.caller_name: Optional[str] = self.astra_db.caller_name self.caller_version: Optional[str] = self.astra_db.caller_version self.client = astra_db.client self.collection_name = collection_name self.base_path: str = f"{self.astra_db.base_path}/{self.collection_name}" def __repr__(self) -> str: return f'AsyncAstraDBCollection[astra_db="{self.astra_db}", collection_name="{self.collection_name}"]' def __eq__(self, other: Any) -> bool: if isinstance(other, AsyncAstraDBCollection): return all( [ self.collection_name == other.collection_name, self.astra_db == other.astra_db, self.caller_name == other.caller_name, self.caller_version == other.caller_version, ] ) else: return False def copy( self, *, collection_name: Optional[str] = None, token: Optional[str] = None, api_endpoint: Optional[str] = None, api_path: Optional[str] = None, api_version: Optional[str] = None, namespace: Optional[str] = None, caller_name: Optional[str] = None, caller_version: Optional[str] = None, ) -> AsyncAstraDBCollection: return AsyncAstraDBCollection( collection_name=collection_name or self.collection_name, astra_db=self.astra_db.copy( token=token, api_endpoint=api_endpoint, api_path=api_path, api_version=api_version, namespace=namespace, caller_name=caller_name, caller_version=caller_version, ), caller_name=caller_name or self.caller_name, caller_version=caller_version or self.caller_version, ) def set_caller( self, caller_name: Optional[str] = None, caller_version: Optional[str] = None, ) -> None: self.astra_db.set_caller( caller_name=caller_name, caller_version=caller_version, ) self.caller_name = caller_name self.caller_version = caller_version def to_sync(self) -> AstraDBCollection: return AstraDBCollection( collection_name=self.collection_name, astra_db=self.astra_db.to_sync(), caller_name=self.caller_name, caller_version=self.caller_version, ) async def _request( self, method: str = http_methods.POST, path: Optional[str] = None, json_data: Optional[Dict[str, Any]] = None, url_params: Optional[Dict[str, Any]] = None, skip_error_check: bool = False, timeout_info: TimeoutInfoWideType = None, **kwargs: Any, ) -> API_RESPONSE: adirect_response = await async_api_request( client=self.client, base_url=self.astra_db.base_url, auth_header=DEFAULT_AUTH_HEADER, token=self.astra_db.token, method=method, json_data=normalize_for_api(json_data), url_params=url_params, path=path, skip_error_check=skip_error_check, caller_name=self.caller_name, caller_version=self.caller_version, timeout=to_httpx_timeout(timeout_info), ) response = restore_from_api(adirect_response) return response async def post_raw_request( self, body: Dict[str, Any], timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: return await self._request( method=http_methods.POST, path=self.base_path, json_data=body, timeout_info=timeout_info, ) async def _get( self, path: Optional[str] = None, options: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> Optional[API_RESPONSE]: full_path = f"{self.base_path}/{path}" if path else self.base_path response = await self._request( method=http_methods.GET, path=full_path, url_params=options, timeout_info=timeout_info, ) if isinstance(response, dict): return response return None async def _put( self, path: Optional[str] = None, document: Optional[API_RESPONSE] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: full_path = f"{self.base_path}/{path}" if path else self.base_path response = await self._request( method=http_methods.PUT, path=full_path, json_data=document, timeout_info=timeout_info, ) return response async def _post( self, path: Optional[str] = None, document: Optional[API_DOC] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: full_path = f"{self.base_path}/{path}" if path else self.base_path response = await self._request( method=http_methods.POST, path=full_path, json_data=document, timeout_info=timeout_info, ) return response def _recast_as_sort_projection( self, vector: List[float], fields: Optional[List[str]] = None ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]: """ Given a vector and optionally a list of fields, reformulate them as a sort, projection pair for regular 'find'-like API calls (with basic validation as well). """ # Must pass a vector if not vector: raise ValueError("Must pass a vector") # Edge case for field selection if fields and "$similarity" in fields: raise ValueError("Please use the `include_similarity` parameter") # Build the new vector parameter sort: Dict[str, Any] = {"$vector": vector} # Build the new fields parameter # Note: do not leave projection={}, make it None # (or it will devour $similarity away in the API response) if fields is not None and len(fields) > 0: projection = {f: 1 for f in fields} else: projection = None return sort, projection async def get( self, path: Optional[str] = None, timeout_info: TimeoutInfoWideType = None ) -> Optional[API_RESPONSE]: """ Retrieve a document from the collection by its path. Args: path (str, optional): The path of the document to retrieve. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The retrieved document. """ return await self._get(path=path, timeout_info=timeout_info) async def find( self, filter: Optional[Dict[str, Any]] = None, projection: Optional[Dict[str, Any]] = None, sort: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Find documents in the collection that match the given filter. Args: filter (dict, optional): Criteria to filter documents. projection (dict, optional): Specifies the fields to return. sort (dict, optional): Specifies the order in which to return matching documents. options (dict, optional): Additional options for the query. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The query response containing matched documents. """ json_query = make_payload( top_level="find", filter=filter, projection=projection, options=options, sort=sort, ) response = await self._post(document=json_query, timeout_info=timeout_info) return response async def vector_find( self, vector: List[float], *, limit: int, filter: Optional[Dict[str, Any]] = None, fields: Optional[List[str]] = None, include_similarity: bool = True, timeout_info: TimeoutInfoWideType = None, ) -> List[API_DOC]: """ Perform a vector-based search in the collection. Args: vector (list): The vector to search with. limit (int): The maximum number of documents to return. filter (dict, optional): Criteria to filter documents. fields (list, optional): Specifies the fields to return. include_similarity (bool, optional): Whether to include similarity score in the result. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: list: A list of documents matching the vector search criteria. """ # Must pass a limit if not limit: raise ValueError("Must pass a limit") # Pre-process the included arguments sort, projection = self._recast_as_sort_projection( vector, fields=fields, ) # Call the underlying find() method to search raw_find_result = await self.find( filter=filter, projection=projection, sort=sort, options={ "limit": limit, "includeSimilarity": include_similarity, }, timeout_info=timeout_info, ) return cast(List[API_DOC], raw_find_result["data"]["documents"]) @staticmethod async def paginate( *, request_method: AsyncPaginableRequestMethod, options: Optional[Dict[str, Any]], prefetched: Optional[int] = None, timeout_info: TimeoutInfoWideType = None, ) -> AsyncGenerator[API_DOC, None]: """ Generate paginated results for a given database query method. Args: request_method (function): The database query method to paginate. options (dict, optional): Options for the database query. prefetched (int, optional): Number of pre-fetched documents. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Yields: dict: The next document in the paginated result set. """ _options = options or {} response0 = await request_method(options=_options) next_page_state = response0["data"]["nextPageState"] options0 = _options if next_page_state is not None and prefetched: async def queued_paginate( queue: asyncio.Queue[Optional[API_DOC]], request_method: AsyncPaginableRequestMethod, options: Optional[Dict[str, Any]], ) -> None: try: async for doc in AsyncAstraDBCollection.paginate( request_method=request_method, options=options ): await queue.put(doc) finally: await queue.put(None) queue: asyncio.Queue[Optional[API_DOC]] = asyncio.Queue(prefetched) options1 = {**options0, **{"pageState": next_page_state}} asyncio.create_task(queued_paginate(queue, request_method, options1)) for document in response0["data"]["documents"]: yield document doc = await queue.get() while doc is not None: yield doc doc = await queue.get() else: for document in response0["data"]["documents"]: yield document while next_page_state is not None: options1 = {**options0, **{"pageState": next_page_state}} response1 = await request_method(options=options1) for document in response1["data"]["documents"]: yield document next_page_state = response1["data"]["nextPageState"] def paginated_find( self, filter: Optional[Dict[str, Any]] = None, projection: Optional[Dict[str, Any]] = None, sort: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, prefetched: Optional[int] = None, timeout_info: TimeoutInfoWideType = None, ) -> AsyncIterator[API_DOC]: """ Perform a paginated search in the collection. Args: filter (dict, optional): Criteria to filter documents. projection (dict, optional): Specifies the fields to return. sort (dict, optional): Specifies the order in which to return matching documents. options (dict, optional): Additional options for the query. prefetched (int, optional): Number of pre-fetched documents timeout_info: a float, or a TimeoutInfo dict, for each single HTTP request. This is a paginated method, that issues several requests as it needs more data. This parameter controls a single request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: generator: A generator yielding documents in the paginated result set. """ partialed_find = partial( self.find, filter=filter, projection=projection, sort=sort, timeout_info=timeout_info, ) return self.paginate( request_method=partialed_find, options=options, prefetched=prefetched, ) async def pop( self, filter: Dict[str, Any], pop: Dict[str, Any], options: Dict[str, Any], timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Pop the last data in the tags array Args: filter (dict): Criteria to identify the document to update. pop (dict): The pop to apply to the tags. options (dict): Additional options for the update operation. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The original document before the update. """ json_query = make_payload( top_level="findOneAndUpdate", filter=filter, update={"$pop": pop}, options=options, ) response = await self._request( method=http_methods.POST, path=self.base_path, json_data=json_query, timeout_info=timeout_info, ) return response async def push( self, filter: Dict[str, Any], push: Dict[str, Any], options: Dict[str, Any], timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Push new data to the tags array Args: filter (dict): Criteria to identify the document to update. push (dict): The push to apply to the tags. options (dict): Additional options for the update operation. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The result of the update operation. """ json_query = make_payload( top_level="findOneAndUpdate", filter=filter, update={"$push": push}, options=options, ) response = await self._request( method=http_methods.POST, path=self.base_path, json_data=json_query, timeout_info=timeout_info, ) return response async def find_one_and_replace( self, replacement: Dict[str, Any], *, filter: Optional[Dict[str, Any]] = None, projection: Optional[Dict[str, Any]] = None, sort: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Find a single document and replace it. Args: replacement (dict): The new document to replace the existing one. filter (dict, optional): Criteria to filter documents. sort (dict, optional): Specifies the order in which to find the document. options (dict, optional): Additional options for the operation. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The result of the find and replace operation. """ json_query = make_payload( top_level="findOneAndReplace", filter=filter, projection=projection, replacement=replacement, options=options, sort=sort, ) response = await self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, timeout_info=timeout_info, ) return response async def vector_find_one_and_replace( self, vector: List[float], replacement: Dict[str, Any], *, filter: Optional[Dict[str, Any]] = None, fields: Optional[List[str]] = None, timeout_info: TimeoutInfoWideType = None, ) -> Union[API_DOC, None]: """ Perform a vector-based search and replace the first matched document. Args: vector (dict): The vector to search with. replacement (dict): The new document to replace the existing one. filter (dict, optional): Criteria to filter documents. fields (list, optional): Specifies the fields to return in the result. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict or None: either the matched document or None if nothing found """ # Pre-process the included arguments sort, projection = self._recast_as_sort_projection( vector, fields=fields, ) # Call the underlying find() method to search raw_find_result = await self.find_one_and_replace( replacement=replacement, filter=filter, projection=projection, sort=sort, timeout_info=timeout_info, ) return cast(Union[API_DOC, None], raw_find_result["data"]["document"]) async def find_one_and_update( self, update: Dict[str, Any], sort: Optional[Dict[str, Any]] = {}, filter: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, projection: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Find a single document and update it. Args: sort (dict, optional): Specifies the order in which to find the document. update (dict): The update to apply to the document. filter (dict, optional): Criteria to filter documents. options (dict, optional): Additional options for the operation. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The result of the find and update operation. """ json_query = make_payload( top_level="findOneAndUpdate", filter=filter, update=update, options=options, sort=sort, projection=projection, ) response = await self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, timeout_info=timeout_info, ) return response async def vector_find_one_and_update( self, vector: List[float], update: Dict[str, Any], *, filter: Optional[Dict[str, Any]] = None, fields: Optional[List[str]] = None, timeout_info: TimeoutInfoWideType = None, ) -> Union[API_DOC, None]: """ Perform a vector-based search and update the first matched document. Args: vector (list): The vector to search with. update (dict): The update to apply to the matched document. filter (dict, optional): Criteria to filter documents before applying the vector search. fields (list, optional): Specifies the fields to return in the updated document. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict or None: The result of the vector-based find and update operation, or None if nothing found """ # Pre-process the included arguments sort, projection = self._recast_as_sort_projection( vector, fields=fields, ) # Call the underlying find() method to search raw_find_result = await self.find_one_and_update( update=update, filter=filter, sort=sort, projection=projection, timeout_info=timeout_info, ) return cast(Union[API_DOC, None], raw_find_result["data"]["document"]) async def find_one_and_delete( self, sort: Optional[Dict[str, Any]] = {}, filter: Optional[Dict[str, Any]] = None, projection: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Find a single document and delete it. Args: sort (dict, optional): Specifies the order in which to find the document. filter (dict, optional): Criteria to filter documents. projection (dict, optional): Specifies the fields to return. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The result of the find and delete operation. """ json_query = make_payload( top_level="findOneAndDelete", filter=filter, sort=sort, projection=projection, ) response = await self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, timeout_info=timeout_info, ) return response async def count_documents( self, filter: Dict[str, Any] = {}, timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: """ Count documents matching a given predicate (expressed as filter). Args: filter (dict, defaults to {}): Criteria to filter documents. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: the response, either {"status": {"count": }} or {"errors": [...]} """ json_query = make_payload( top_level="countDocuments", filter=filter, ) response = await self._post(document=json_query, timeout_info=timeout_info) return response async def find_one( self, filter: Optional[Dict[str, Any]] = {}, projection: Optional[Dict[str, Any]] = {}, sort: Optional[Dict[str, Any]] = {}, options: Optional[Dict[str, Any]] = {}, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Find a single document in the collection. Args: filter (dict, optional): Criteria to filter documents. projection (dict, optional): Specifies the fields to return. sort (dict, optional): Specifies the order in which to return the document. options (dict, optional): Additional options for the query. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: the response, either {"data": {"document": }} or {"data": {"document": None}} depending on whether a matching document is found or not. """ json_query = make_payload( top_level="findOne", filter=filter, projection=projection, options=options, sort=sort, ) response = await self._post(document=json_query, timeout_info=timeout_info) return response async def vector_find_one( self, vector: List[float], *, filter: Optional[Dict[str, Any]] = None, fields: Optional[List[str]] = None, include_similarity: bool = True, timeout_info: TimeoutInfoWideType = None, ) -> Union[API_DOC, None]: """ Perform a vector-based search to find a single document in the collection. Args: vector (list): The vector to search with. filter (dict, optional): Additional criteria to filter documents. fields (list, optional): Specifies the fields to return in the result. include_similarity (bool, optional): Whether to include similarity score in the result. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict or None: The found document or None if no matching document is found. """ # Pre-process the included arguments sort, projection = self._recast_as_sort_projection( vector, fields=fields, ) # Call the underlying find() method to search raw_find_result = await self.find_one( filter=filter, projection=projection, sort=sort, options={"includeSimilarity": include_similarity}, timeout_info=timeout_info, ) return cast(Union[API_DOC, None], raw_find_result["data"]["document"]) async def insert_one( self, document: API_DOC, failures_allowed: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Insert a single document into the collection. Args: document (dict): The document to insert. failures_allowed (bool): Whether to allow failures in the insert operation. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database after the insert operation. """ json_query = make_payload(top_level="insertOne", document=document) response = await self._request( method=http_methods.POST, path=self.base_path, json_data=json_query, skip_error_check=failures_allowed, timeout_info=timeout_info, ) return response async def insert_many( self, documents: List[API_DOC], options: Optional[Dict[str, Any]] = None, partial_failures_allowed: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Insert multiple documents into the collection. Args: documents (list): A list of documents to insert. options (dict, optional): Additional options for the insert operation. partial_failures_allowed (bool, optional): Whether to allow partial failures through the insertion (i.e. on some documents). timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database after the insert operation. """ json_query = make_payload( top_level="insertMany", documents=documents, options=options ) response = await self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, skip_error_check=partial_failures_allowed, timeout_info=timeout_info, ) return response async def chunked_insert_many( self, documents: List[API_DOC], options: Optional[Dict[str, Any]] = None, partial_failures_allowed: bool = False, chunk_size: int = MAX_INSERT_NUM_DOCUMENTS, concurrency: int = 1, timeout_info: TimeoutInfoWideType = None, ) -> List[Union[API_RESPONSE, Exception]]: """ Insert multiple documents into the collection, handling chunking and optionally with concurrent insertions. Args: documents (list): A list of documents to insert. options (dict, optional): Additional options for the insert operation. partial_failures_allowed (bool, optional): Whether to allow partial failures in the chunk. Should be used combined with options={"ordered": False} in most cases. chunk_size (int, optional): Override the default insertion chunk size. concurrency (int, optional): The number of concurrent chunk insertions. Default is no concurrency. timeout_info: a float, or a TimeoutInfo dict, for each single HTTP request. This method runs a number of HTTP requests as it works on chunked data. The timeout refers to each individual such request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: list: The responses from the database after the chunked insert operation. This is a list of individual responses from the API: the caller will need to inspect them all, e.g. to collate the inserted IDs. """ sem = asyncio.Semaphore(concurrency) async def concurrent_insert_many( docs: List[API_DOC], index: int, partial_failures_allowed: bool, ) -> Union[API_RESPONSE, Exception]: async with sem: logger.debug(f"Processing chunk #{index + 1} of size {len(docs)}") try: return await self.insert_many( documents=docs, options=options, partial_failures_allowed=partial_failures_allowed, timeout_info=timeout_info, ) except APIRequestError as e: if partial_failures_allowed: return e else: raise e if concurrency > 1: tasks = [ asyncio.create_task( concurrent_insert_many( documents[i : i + chunk_size], i, partial_failures_allowed ) ) for i in range(0, len(documents), chunk_size) ] results = await asyncio.gather(*tasks, return_exceptions=False) else: # this ensures the expectation of # "sequential strictly obeys fail-fast if ordered and concurrency==1" results = [ await concurrent_insert_many( documents[i : i + chunk_size], i, partial_failures_allowed ) for i in range(0, len(documents), chunk_size) ] return results async def update_one( self, filter: Dict[str, Any], update: Dict[str, Any], sort: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Update a single document in the collection. Args: filter (dict): Criteria to identify the document to update. update (dict): The update to apply to the document. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database after the update operation. """ json_query = make_payload( top_level="updateOne", filter=filter, update=update, sort=sort, ) response = await self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, timeout_info=timeout_info, ) return response async def update_many( self, filter: Dict[str, Any], update: Dict[str, Any], options: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Updates multiple documents in the collection. Args: filter (dict): Criteria to identify the document to update. update (dict): The update to apply to the document. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database after the update operation. """ json_query = make_payload( top_level="updateMany", filter=filter, update=update, options=options, ) response = await self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, timeout_info=timeout_info, ) return response async def replace( self, path: str, document: API_DOC, timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: """ Replace a document in the collection. Args: path (str): The path to the document to replace. document (dict): The new document to replace the existing one. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database after the replace operation. """ return await self._put(path=path, document=document, timeout_info=timeout_info) async def delete_one( self, id: str, sort: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Delete a single document from the collection based on its ID. Args: id (str): The ID of the document to delete. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database after the delete operation. """ json_query = make_payload( top_level="deleteOne", filter={"_id": id}, sort=sort, ) response = await self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, timeout_info=timeout_info, ) return response async def delete_one_by_predicate( self, filter: Dict[str, Any], sort: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Delete a single document from the collection based on a filter clause Args: filter: any filter dictionary timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database after the delete operation. """ json_query = make_payload( top_level="deleteOne", filter=filter, sort=sort, ) response = await self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, timeout_info=timeout_info, ) return response async def delete_many( self, filter: Dict[str, Any], skip_error_check: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Delete many documents from the collection based on a filter condition Args: filter (dict): Criteria to identify the documents to delete. skip_error_check (bool): whether to ignore the check for API error and return the response untouched. Default is False. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database after the delete operation. """ json_query = { "deleteMany": { "filter": filter, } } response = await self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, skip_error_check=skip_error_check, timeout_info=timeout_info, ) return response async def chunked_delete_many( self, filter: Dict[str, Any], timeout_info: TimeoutInfoWideType = None ) -> List[API_RESPONSE]: """ Delete many documents from the collection based on a filter condition, chaining several API calls until exhaustion of the documents to delete. Args: filter (dict): Criteria to identify the documents to delete. timeout_info: a float, or a TimeoutInfo dict, for each single HTTP request. This method runs a number of HTTP requests as it works on a pagination basis. The timeout refers to each individual such request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: List[dict]: The responses from the database from all the calls """ responses = [] must_proceed = True while must_proceed: dm_response = await self.delete_many( filter=filter, timeout_info=timeout_info ) responses.append(dm_response) must_proceed = dm_response.get("status", {}).get("moreData", False) return responses async def clear(self, timeout_info: TimeoutInfoWideType = None) -> API_RESPONSE: """ Clear the collection, deleting all documents Args: timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database. """ clear_response = await self.delete_many(filter={}, timeout_info=timeout_info) if clear_response.get("status", {}).get("deletedCount") != -1: raise ValueError( f"Could not issue a clear-collection API command (response: {json.dumps(clear_response)})." ) return clear_response async def delete_subdocument( self, id: str, subdoc: str, timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: """ Delete a subdocument or field from a document in the collection. Args: id (str): The ID of the document containing the subdocument. subdoc (str): The key of the subdocument or field to remove. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database after the update operation. """ json_query = { "findOneAndUpdate": { "filter": {"_id": id}, "update": {"$unset": {subdoc: ""}}, } } response = await self._request( method=http_methods.POST, path=f"{self.base_path}", json_data=json_query, timeout_info=timeout_info, ) return response @deprecation.deprecated( # type: ignore deprecated_in="0.7.0", removed_in="1.0.0", current_version=__version__, details="Use the 'upsert_one' method instead", ) async def upsert( self, document: API_DOC, timeout_info: TimeoutInfoWideType = None ) -> str: return await self.upsert_one(document, timeout_info=timeout_info) async def upsert_one( self, document: API_DOC, timeout_info: TimeoutInfoWideType = None, ) -> str: """ Emulate an upsert operation for a single document in the collection. This method attempts to insert the document. If a document with the same _id exists, it updates the existing document. Args: document (dict): The document to insert or update. timeout_info: a float, or a TimeoutInfo dict, for the HTTP requests. This method may issue one or two requests, depending on what is detected on DB. This timeout controls each HTTP request individually. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: str: The _id of the inserted or updated document. """ # Build the payload for the insert attempt result = await self.insert_one( document, failures_allowed=True, timeout_info=timeout_info ) # If the call failed because of preexisting doc, then we replace it if "errors" in result: if ( "errorCode" in result["errors"][0] and result["errors"][0]["errorCode"] == "DOCUMENT_ALREADY_EXISTS" ): # Now we attempt the update result = await self.find_one_and_replace( replacement=document, filter={"_id": document["_id"]}, timeout_info=timeout_info, ) upserted_id = cast(str, result["data"]["document"]["_id"]) else: raise ValueError(result) else: if result.get("status", {}).get("insertedIds", []): upserted_id = cast(str, result["status"]["insertedIds"][0]) else: raise ValueError("Unexplained empty insertedIds from API") return upserted_id async def upsert_many( self, documents: list[API_DOC], concurrency: int = 1, partial_failures_allowed: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> List[Union[str, Exception]]: """ Emulate an upsert operation for multiple documents in the collection. This method attempts to insert the documents. If a document with the same _id exists, it updates the existing document. Args: documents (List[dict]): The documents to insert or update. concurrency (int, optional): The number of concurrent upserts. partial_failures_allowed (bool, optional): Whether to allow partial failures in the batch. timeout_info: a float, or a TimeoutInfo dict, for each HTTP request. This method issues a separate HTTP request for each document to insert: the timeout controls each such request individually. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: List[Union[str, Exception]]: A list of "_id"s of the inserted or updated documents. """ sem = asyncio.Semaphore(concurrency) async def concurrent_upsert(doc: API_DOC) -> str: async with sem: return await self.upsert_one(document=doc, timeout_info=timeout_info) tasks = [asyncio.create_task(concurrent_upsert(doc)) for doc in documents] results = await asyncio.gather( *tasks, return_exceptions=partial_failures_allowed ) for result in results: if isinstance(result, BaseException) and not isinstance(result, Exception): raise result return results # type: ignore class AstraDB: # Initialize the shared httpx client as a class attribute client = httpx.Client() def __init__( self, token: str, api_endpoint: str, api_path: Optional[str] = None, api_version: Optional[str] = None, namespace: Optional[str] = None, caller_name: Optional[str] = None, caller_version: Optional[str] = None, ) -> None: """ Initialize an Astra DB instance. Args: token (str): Authentication token for Astra DB. api_endpoint (str): API endpoint URL. api_path (str, optional): used to override default URI construction api_version (str, optional): to override default URI construction namespace (str, optional): Namespace for the database. caller_name (str, optional): identity of the caller ("my_framework") caller_version (str, optional): version of the caller code ("1.0.3") """ self.caller_name = caller_name self.caller_version = caller_version if token is None or api_endpoint is None: raise AssertionError("Must provide token and api_endpoint") if namespace is None: logger.info( f"ASTRA_DB_KEYSPACE is not set. Defaulting to '{DEFAULT_KEYSPACE_NAME}'" ) namespace = DEFAULT_KEYSPACE_NAME # Store the API token self.token = token self.api_endpoint = api_endpoint # Set the Base URL for the API calls self.base_url = self.api_endpoint.strip("/") # Set the API version and path from the call self.api_path = (api_path or DEFAULT_JSON_API_PATH).strip("/") self.api_version = (api_version or DEFAULT_JSON_API_VERSION).strip("/") # Set the namespace self.namespace = namespace # Finally, construct the full base path self.base_path: str = f"/{self.api_path}/{self.api_version}/{self.namespace}" def __repr__(self) -> str: return f'AstraDB[endpoint="{self.base_url}", keyspace="{self.namespace}"]' def __eq__(self, other: Any) -> bool: if isinstance(other, AstraDB): # work on the "normalized" quantities (stripped, etc) return all( [ self.token == other.token, self.base_url == other.base_url, self.base_path == other.base_path, self.caller_name == other.caller_name, self.caller_version == other.caller_version, ] ) else: return False def copy( self, *, token: Optional[str] = None, api_endpoint: Optional[str] = None, api_path: Optional[str] = None, api_version: Optional[str] = None, namespace: Optional[str] = None, caller_name: Optional[str] = None, caller_version: Optional[str] = None, ) -> AstraDB: return AstraDB( token=token or self.token, api_endpoint=api_endpoint or self.base_url, api_path=api_path or self.api_path, api_version=api_version or self.api_version, namespace=namespace or self.namespace, caller_name=caller_name or self.caller_name, caller_version=caller_version or self.caller_version, ) def to_async(self) -> AsyncAstraDB: return AsyncAstraDB( token=self.token, api_endpoint=self.base_url, api_path=self.api_path, api_version=self.api_version, namespace=self.namespace, caller_name=self.caller_name, caller_version=self.caller_version, ) def set_caller( self, caller_name: Optional[str] = None, caller_version: Optional[str] = None, ) -> None: self.caller_name = caller_name self.caller_version = caller_version def _request( self, method: str = http_methods.POST, path: Optional[str] = None, json_data: Optional[Dict[str, Any]] = None, url_params: Optional[Dict[str, Any]] = None, skip_error_check: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: direct_response = api_request( client=self.client, base_url=self.base_url, auth_header=DEFAULT_AUTH_HEADER, token=self.token, method=method, json_data=normalize_for_api(json_data), url_params=url_params, path=path, skip_error_check=skip_error_check, caller_name=self.caller_name, caller_version=self.caller_version, timeout=to_httpx_timeout(timeout_info), ) response = restore_from_api(direct_response) return response def post_raw_request( self, body: Dict[str, Any], timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: return self._request( method=http_methods.POST, path=self.base_path, json_data=body, timeout_info=timeout_info, ) def collection(self, collection_name: str) -> AstraDBCollection: """ Retrieve a collection from the database. Args: collection_name (str): The name of the collection to retrieve. Returns: AstraDBCollection: The collection object. """ return AstraDBCollection(collection_name=collection_name, astra_db=self) def get_collections( self, options: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Retrieve a list of collections from the database. Args: options (dict, optional): Options to get the collection list timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: An object containing the list of collections in the database: {"status": {"collections": [...]}} """ # Parse the options parameter if options is None: options = {} json_query = make_payload( top_level="findCollections", options=options, ) response = self._request( method=http_methods.POST, path=self.base_path, json_data=json_query, timeout_info=timeout_info, ) return response def create_collection( self, collection_name: str, *, options: Optional[Dict[str, Any]] = None, dimension: Optional[int] = None, metric: Optional[str] = None, service_dict: Optional[Dict[str, str]] = None, timeout_info: TimeoutInfoWideType = None, ) -> AstraDBCollection: """ Create a new collection in the database. Args: collection_name (str): The name of the collection to create. options (dict, optional): Options for the collection. dimension (int, optional): Dimension for vector search. metric (str, optional): Metric choice for vector search. service_dict (dict, optional): a definition for the $vectorize service NOTE: This feature is under current development. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: AstraDBCollection: The created collection object. """ # options from named params vector_options = { k: v for k, v in { "dimension": dimension, "metric": metric, "service": service_dict, }.items() if v is not None } # overlap/merge with stuff in options.vector dup_params = set((options or {}).get("vector", {}).keys()) & set( vector_options.keys() ) # If any params are duplicated, we raise an error if dup_params: dups = ", ".join(sorted(dup_params)) raise ValueError( f"Parameter(s) {dups} passed both to the method and in the options" ) # Build our options dictionary if we have vector options if vector_options: options = options or {} options["vector"] = { **options.get("vector", {}), **vector_options, } # Build the final json payload jsondata = { k: v for k, v in {"name": collection_name, "options": options}.items() if v is not None } # Make the request to the endpoint self._request( method=http_methods.POST, path=f"{self.base_path}", json_data={"createCollection": jsondata}, timeout_info=timeout_info, ) # Get the instance object as the return of the call return AstraDBCollection(astra_db=self, collection_name=collection_name) def delete_collection( self, collection_name: str, timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: """ Delete a collection from the database. Args: collection_name (str): The name of the collection to delete. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database. """ # Make sure we provide a collection name if not collection_name: raise ValueError("Must provide a collection name") response = self._request( method=http_methods.POST, path=f"{self.base_path}", json_data={"deleteCollection": {"name": collection_name}}, timeout_info=timeout_info, ) return response @deprecation.deprecated( # type: ignore deprecated_in="0.7.0", removed_in="1.0.0", current_version=__version__, details="Use the 'AstraDBCollection.clear()' method instead", ) def truncate_collection( self, collection_name: str, timeout_info: TimeoutInfoWideType = None ) -> AstraDBCollection: """ Clear a collection in the database, deleting all stored documents. Args: collection_name (str): The name of the collection to clear. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: collection: an AstraDBCollection instance """ collection = AstraDBCollection( collection_name=collection_name, astra_db=self, ) clear_response = collection.clear(timeout_info=timeout_info) if clear_response.get("status", {}).get("deletedCount") != -1: raise ValueError( f"Could not issue a truncation API command (response: {json.dumps(clear_response)})." ) # return the collection itself return collection class AsyncAstraDB: def __init__( self, token: str, api_endpoint: str, api_path: Optional[str] = None, api_version: Optional[str] = None, namespace: Optional[str] = None, caller_name: Optional[str] = None, caller_version: Optional[str] = None, ) -> None: """ Initialize an Astra DB instance. Args: token (str): Authentication token for Astra DB. api_endpoint (str): API endpoint URL. api_path (str, optional): used to override default URI construction api_version (str, optional): to override default URI construction namespace (str, optional): Namespace for the database. caller_name (str, optional): identity of the caller ("my_framework") caller_version (str, optional): version of the caller code ("1.0.3") """ self.caller_name = caller_name self.caller_version = caller_version self.client = httpx.AsyncClient() if token is None or api_endpoint is None: raise AssertionError("Must provide token and api_endpoint") if namespace is None: logger.info( f"ASTRA_DB_KEYSPACE is not set. Defaulting to '{DEFAULT_KEYSPACE_NAME}'" ) namespace = DEFAULT_KEYSPACE_NAME # Store the API token self.token = token self.api_endpoint = api_endpoint # Set the Base URL for the API calls self.base_url = self.api_endpoint.strip("/") # Set the API version and path from the call self.api_path = (api_path or DEFAULT_JSON_API_PATH).strip("/") self.api_version = (api_version or DEFAULT_JSON_API_VERSION).strip("/") # Set the namespace self.namespace = namespace # Finally, construct the full base path self.base_path: str = f"/{self.api_path}/{self.api_version}/{self.namespace}" def __repr__(self) -> str: return f'AsyncAstraDB[endpoint="{self.base_url}", keyspace="{self.namespace}"]' def __eq__(self, other: Any) -> bool: if isinstance(other, AsyncAstraDB): # work on the "normalized" quantities (stripped, etc) return all( [ self.token == other.token, self.base_url == other.base_url, self.base_path == other.base_path, self.caller_name == other.caller_name, self.caller_version == other.caller_version, ] ) else: return False async def __aenter__(self) -> AsyncAstraDB: return self async def __aexit__( self, exc_type: Optional[Type[BaseException]] = None, exc_value: Optional[BaseException] = None, traceback: Optional[TracebackType] = None, ) -> None: await self.client.aclose() def copy( self, *, token: Optional[str] = None, api_endpoint: Optional[str] = None, api_path: Optional[str] = None, api_version: Optional[str] = None, namespace: Optional[str] = None, caller_name: Optional[str] = None, caller_version: Optional[str] = None, ) -> AsyncAstraDB: return AsyncAstraDB( token=token or self.token, api_endpoint=api_endpoint or self.base_url, api_path=api_path or self.api_path, api_version=api_version or self.api_version, namespace=namespace or self.namespace, caller_name=caller_name or self.caller_name, caller_version=caller_version or self.caller_version, ) def to_sync(self) -> AstraDB: return AstraDB( token=self.token, api_endpoint=self.base_url, api_path=self.api_path, api_version=self.api_version, namespace=self.namespace, caller_name=self.caller_name, caller_version=self.caller_version, ) def set_caller( self, caller_name: Optional[str] = None, caller_version: Optional[str] = None, ) -> None: self.caller_name = caller_name self.caller_version = caller_version async def _request( self, method: str = http_methods.POST, path: Optional[str] = None, json_data: Optional[Dict[str, Any]] = None, url_params: Optional[Dict[str, Any]] = None, skip_error_check: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: adirect_response = await async_api_request( client=self.client, base_url=self.base_url, auth_header=DEFAULT_AUTH_HEADER, token=self.token, method=method, json_data=normalize_for_api(json_data), url_params=url_params, path=path, skip_error_check=skip_error_check, caller_name=self.caller_name, caller_version=self.caller_version, timeout=to_httpx_timeout(timeout_info), ) response = restore_from_api(adirect_response) return response async def post_raw_request( self, body: Dict[str, Any], timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: return await self._request( method=http_methods.POST, path=self.base_path, json_data=body, timeout_info=timeout_info, ) async def collection(self, collection_name: str) -> AsyncAstraDBCollection: """ Retrieve a collection from the database. Args: collection_name (str): The name of the collection to retrieve. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: AstraDBCollection: The collection object. """ return AsyncAstraDBCollection(collection_name=collection_name, astra_db=self) async def get_collections( self, options: Optional[Dict[str, Any]] = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ Retrieve a list of collections from the database. Args: options (dict, optional): Options to get the collection list timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: An object containing the list of collections in the database: {"status": {"collections": [...]}} """ # Parse the options parameter if options is None: options = {} json_query = make_payload( top_level="findCollections", options=options, ) response = await self._request( method=http_methods.POST, path=self.base_path, json_data=json_query, timeout_info=timeout_info, ) return response async def create_collection( self, collection_name: str, *, options: Optional[Dict[str, Any]] = None, dimension: Optional[int] = None, metric: Optional[str] = None, service_dict: Optional[Dict[str, str]] = None, timeout_info: TimeoutInfoWideType = None, ) -> AsyncAstraDBCollection: """ Create a new collection in the database. Args: collection_name (str): The name of the collection to create. options (dict, optional): Options for the collection. dimension (int, optional): Dimension for vector search. metric (str, optional): Metric choice for vector search. service_dict (dict, optional): a definition for the $vectorize service NOTE: This feature is under current development. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: AsyncAstraDBCollection: The created collection object. """ # options from named params vector_options = { k: v for k, v in { "dimension": dimension, "metric": metric, "service": service_dict, }.items() if v is not None } # overlap/merge with stuff in options.vector dup_params = set((options or {}).get("vector", {}).keys()) & set( vector_options.keys() ) # If any params are duplicated, we raise an error if dup_params: dups = ", ".join(sorted(dup_params)) raise ValueError( f"Parameter(s) {dups} passed both to the method and in the options" ) # Build our options dictionary if we have vector options if vector_options: options = options or {} options["vector"] = { **options.get("vector", {}), **vector_options, } # Build the final json payload jsondata = { k: v for k, v in {"name": collection_name, "options": options}.items() if v is not None } # Make the request to the endpoint await self._request( method=http_methods.POST, path=f"{self.base_path}", json_data={"createCollection": jsondata}, timeout_info=timeout_info, ) # Get the instance object as the return of the call return AsyncAstraDBCollection(astra_db=self, collection_name=collection_name) async def delete_collection( self, collection_name: str, timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: """ Delete a collection from the database. Args: collection_name (str): The name of the collection to delete. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: dict: The response from the database. """ # Make sure we provide a collection name if not collection_name: raise ValueError("Must provide a collection name") response = await self._request( method=http_methods.POST, path=f"{self.base_path}", json_data={"deleteCollection": {"name": collection_name}}, timeout_info=timeout_info, ) return response @deprecation.deprecated( # type: ignore deprecated_in="0.7.0", removed_in="1.0.0", current_version=__version__, details="Use the 'AsyncAstraDBCollection.clear()' method instead", ) async def truncate_collection( self, collection_name: str, timeout_info: TimeoutInfoWideType = None ) -> AsyncAstraDBCollection: """ Clear a collection in the database, deleting all stored documents. Args: collection_name (str): The name of the collection to clear. timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. Note that a 'read' timeout event will not block the action taken by the API server if it has received the request already. Returns: collection: an AsyncAstraDBCollection instance """ collection = AsyncAstraDBCollection( collection_name=collection_name, astra_db=self, ) clear_response = await collection.clear(timeout_info=timeout_info) if clear_response.get("status", {}).get("deletedCount") != -1: raise ValueError( f"Could not issue a truncation API command (response: {json.dumps(clear_response)})." ) # return the collection itself return collection