Spaces:
Running
Running
# Copyright DataStax, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from __future__ import annotations | |
import asyncio | |
import deprecation | |
import httpx | |
import logging | |
import json | |
import threading | |
from collections.abc import ( | |
AsyncGenerator, | |
AsyncIterator, | |
Generator, | |
Iterator, | |
) | |
from concurrent.futures import ThreadPoolExecutor | |
from functools import partial | |
from queue import Queue | |
from types import TracebackType | |
from typing import ( | |
Any, | |
cast, | |
Dict, | |
List, | |
Optional, | |
Tuple, | |
Union, | |
Type, | |
) | |
from astrapy import __version__ | |
from astrapy.core.api import APIRequestError, api_request, async_api_request | |
from astrapy.core.defaults import ( | |
DEFAULT_AUTH_HEADER, | |
DEFAULT_JSON_API_PATH, | |
DEFAULT_JSON_API_VERSION, | |
DEFAULT_KEYSPACE_NAME, | |
MAX_INSERT_NUM_DOCUMENTS, | |
) | |
from astrapy.core.utils import ( | |
convert_vector_to_floats, | |
make_payload, | |
normalize_for_api, | |
restore_from_api, | |
http_methods, | |
to_httpx_timeout, | |
TimeoutInfoWideType, | |
) | |
from astrapy.core.core_types import ( | |
API_DOC, | |
API_RESPONSE, | |
PaginableRequestMethod, | |
AsyncPaginableRequestMethod, | |
) | |
logger = logging.getLogger(__name__) | |
class AstraDBCollection: | |
# Initialize the shared httpx client as a class attribute | |
client = httpx.Client() | |
def __init__( | |
self, | |
collection_name: str, | |
astra_db: Optional[AstraDB] = None, | |
token: Optional[str] = None, | |
api_endpoint: Optional[str] = None, | |
namespace: Optional[str] = None, | |
caller_name: Optional[str] = None, | |
caller_version: Optional[str] = None, | |
) -> None: | |
""" | |
Initialize an AstraDBCollection instance. | |
Args: | |
collection_name (str): The name of the collection. | |
astra_db (AstraDB, optional): An instance of Astra DB. | |
token (str, optional): Authentication token for Astra DB. | |
api_endpoint (str, optional): API endpoint URL. | |
namespace (str, optional): Namespace for the database. | |
caller_name (str, optional): identity of the caller ("my_framework") | |
If passing a client, its caller is used as fallback | |
caller_version (str, optional): version of the caller code ("1.0.3") | |
If passing a client, its caller is used as fallback | |
""" | |
# Check for presence of the Astra DB object | |
if astra_db is None: | |
if token is None or api_endpoint is None: | |
raise AssertionError("Must provide token and api_endpoint") | |
astra_db = AstraDB( | |
token=token, | |
api_endpoint=api_endpoint, | |
namespace=namespace, | |
caller_name=caller_name, | |
caller_version=caller_version, | |
) | |
else: | |
# if astra_db passed, copy and apply possible overrides | |
astra_db = astra_db.copy( | |
token=token, | |
api_endpoint=api_endpoint, | |
namespace=namespace, | |
caller_name=caller_name, | |
caller_version=caller_version, | |
) | |
# Set the remaining instance attributes | |
self.astra_db = astra_db | |
self.caller_name: Optional[str] = self.astra_db.caller_name | |
self.caller_version: Optional[str] = self.astra_db.caller_version | |
self.collection_name = collection_name | |
self.base_path: str = f"{self.astra_db.base_path}/{self.collection_name}" | |
def __repr__(self) -> str: | |
return f'AstraDBCollection[astra_db="{self.astra_db}", collection_name="{self.collection_name}"]' | |
def __eq__(self, other: Any) -> bool: | |
if isinstance(other, AstraDBCollection): | |
return all( | |
[ | |
self.collection_name == other.collection_name, | |
self.astra_db == other.astra_db, | |
self.caller_name == other.caller_name, | |
self.caller_version == other.caller_version, | |
] | |
) | |
else: | |
return False | |
def copy(
    self,
    *,
    collection_name: Optional[str] = None,
    token: Optional[str] = None,
    api_endpoint: Optional[str] = None,
    api_path: Optional[str] = None,
    api_version: Optional[str] = None,
    namespace: Optional[str] = None,
    caller_name: Optional[str] = None,
    caller_version: Optional[str] = None,
) -> AstraDBCollection:
    """
    Return a new collection object, optionally overriding some settings.

    The database client is copied with the connection-related arguments
    passed through to its own `copy` method. For `collection_name`,
    `caller_name` and `caller_version`, a falsy argument means
    "keep this collection's current value".

    Args:
        collection_name (str, optional): name for the new collection object.
        token (str, optional): passed to the client's `copy`.
        api_endpoint (str, optional): passed to the client's `copy`.
        api_path (str, optional): passed to the client's `copy`.
        api_version (str, optional): passed to the client's `copy`.
        namespace (str, optional): passed to the client's `copy`.
        caller_name (str, optional): caller identity for the new object.
        caller_version (str, optional): caller version for the new object.

    Returns:
        AstraDBCollection: the newly created collection object.
    """
    copied_db = self.astra_db.copy(
        token=token,
        api_endpoint=api_endpoint,
        api_path=api_path,
        api_version=api_version,
        namespace=namespace,
        caller_name=caller_name,
        caller_version=caller_version,
    )
    return AstraDBCollection(
        collection_name=collection_name or self.collection_name,
        astra_db=copied_db,
        caller_name=caller_name or self.caller_name,
        caller_version=caller_version or self.caller_version,
    )
def to_async(self) -> AsyncAstraDBCollection:
    """
    Build the asynchronous counterpart of this collection.

    Returns:
        AsyncAstraDBCollection: a collection with the same name and caller
            identity, backed by the async version of the database client.
    """
    async_database = self.astra_db.to_async()
    return AsyncAstraDBCollection(
        collection_name=self.collection_name,
        astra_db=async_database,
        caller_name=self.caller_name,
        caller_version=self.caller_version,
    )
def set_caller(
    self,
    caller_name: Optional[str] = None,
    caller_version: Optional[str] = None,
) -> None:
    """
    Set the caller identity on the database client and on this collection.

    Args:
        caller_name (str, optional): identity of the caller.
        caller_version (str, optional): version of the caller code.
    """
    # Update the client first; the collection follows only if that succeeds
    self.astra_db.set_caller(
        caller_name=caller_name,
        caller_version=caller_version,
    )
    self.caller_name, self.caller_version = caller_name, caller_version
def _request(
    self,
    method: str = http_methods.POST,
    path: Optional[str] = None,
    json_data: Optional[Dict[str, Any]] = None,
    url_params: Optional[Dict[str, Any]] = None,
    skip_error_check: bool = False,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Send one HTTP request through the shared client and return the response.

    The JSON body is passed through `normalize_for_api` before sending and
    the response through `restore_from_api` before being returned.

    Args:
        method (str): the HTTP method to use (defaults to POST).
        path (str, optional): the request path.
        json_data (dict, optional): the JSON body of the request.
        url_params (dict, optional): URL query parameters.
        skip_error_check (bool): forwarded to `api_request`.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.

    Returns:
        dict: the response, as returned by `restore_from_api`.
    """
    payload = normalize_for_api(json_data)
    timeout = to_httpx_timeout(timeout_info)
    raw_response = api_request(
        client=self.client,
        base_url=self.astra_db.base_url,
        auth_header=DEFAULT_AUTH_HEADER,
        token=self.astra_db.token,
        method=method,
        json_data=payload,
        url_params=url_params,
        path=path,
        skip_error_check=skip_error_check,
        caller_name=self.caller_name,
        caller_version=self.caller_version,
        timeout=timeout,
    )
    return restore_from_api(raw_response)
def post_raw_request(
    self, body: Dict[str, Any], timeout_info: TimeoutInfoWideType = None
) -> API_RESPONSE:
    """
    POST an arbitrary JSON body to the collection's base path.

    Args:
        body (dict): the JSON payload, sent as given.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.

    Returns:
        dict: the API response.
    """
    api_response = self._request(
        method=http_methods.POST,
        path=self.base_path,
        json_data=body,
        timeout_info=timeout_info,
    )
    return api_response
def _get(
    self,
    path: Optional[str] = None,
    options: Optional[Dict[str, Any]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> Optional[API_RESPONSE]:
    """
    Issue a GET request on the collection, optionally on a sub-path.

    Args:
        path (str, optional): appended to the collection base path
            (separated by "/") when non-empty.
        options (dict, optional): sent as URL query parameters.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.

    Returns:
        dict or None: the response if it is a dict, otherwise None.
    """
    target_path = self.base_path if not path else f"{self.base_path}/{path}"
    result = self._request(
        method=http_methods.GET,
        path=target_path,
        url_params=options,
        timeout_info=timeout_info,
    )
    # Anything that is not a dict is reported as "nothing found"
    return result if isinstance(result, dict) else None
def _put(
    self,
    path: Optional[str] = None,
    document: Optional[API_RESPONSE] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue a PUT request on the collection, optionally on a sub-path.

    Args:
        path (str, optional): appended to the collection base path
            (separated by "/") when non-empty.
        document (dict, optional): the JSON body of the request.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.

    Returns:
        dict: the API response.
    """
    target_path = self.base_path if not path else f"{self.base_path}/{path}"
    return self._request(
        method=http_methods.PUT,
        path=target_path,
        json_data=document,
        timeout_info=timeout_info,
    )
def _post(
    self,
    path: Optional[str] = None,
    document: Optional[API_DOC] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue a POST request on the collection, optionally on a sub-path.

    Args:
        path (str, optional): appended to the collection base path
            (separated by "/") when non-empty.
        document (dict, optional): the JSON body of the request.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.

    Returns:
        dict: the API response.
    """
    target_path = self.base_path if not path else f"{self.base_path}/{path}"
    return self._request(
        method=http_methods.POST,
        path=target_path,
        json_data=document,
        timeout_info=timeout_info,
    )
def _recast_as_sort_projection( | |
self, vector: List[float], fields: Optional[List[str]] = None | |
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]: | |
""" | |
Given a vector and optionally a list of fields, | |
reformulate them as a sort, projection pair for regular | |
'find'-like API calls (with basic validation as well). | |
""" | |
# Must pass a vector | |
if not vector: | |
raise ValueError("Must pass a vector") | |
# Edge case for field selection | |
if fields and "$similarity" in fields: | |
raise ValueError("Please use the `include_similarity` parameter") | |
# Build the new vector parameter | |
sort: Dict[str, Any] = {"$vector": vector} | |
# Build the new fields parameter | |
# Note: do not leave projection={}, make it None | |
# (or it will devour $similarity away in the API response) | |
if fields is not None and len(fields) > 0: | |
projection = {f: 1 for f in fields} | |
else: | |
projection = None | |
return sort, projection | |
def get(
    self, path: Optional[str] = None, timeout_info: TimeoutInfoWideType = None
) -> Optional[API_RESPONSE]:
    """
    Fetch a document from the collection with a GET request on its path.

    Args:
        path (str, optional): The path of the document to retrieve.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict or None: the response, or None if the response is not a dict.
    """
    retrieved = self._get(path=path, timeout_info=timeout_info)
    return retrieved
def find(
    self,
    filter: Optional[Dict[str, Any]] = None,
    projection: Optional[Dict[str, Any]] = None,
    sort: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Run a "find" command and return one page of matching documents.

    Args:
        filter (dict, optional): Criteria to filter documents.
        projection (dict, optional): Specifies the fields to return.
        sort (dict, optional): Specifies the order of the returned documents.
        options (dict, optional): Additional options for the query.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The query response containing matched documents.
    """
    find_command = make_payload(
        top_level="find",
        filter=filter,
        projection=projection,
        options=options,
        sort=sort,
    )
    return self._post(document=find_command, timeout_info=timeout_info)
def vector_find( | |
self, | |
vector: List[float], | |
*, | |
limit: int, | |
filter: Optional[Dict[str, Any]] = None, | |
fields: Optional[List[str]] = None, | |
include_similarity: bool = True, | |
timeout_info: TimeoutInfoWideType = None, | |
) -> List[API_DOC]: | |
""" | |
Perform a vector-based search in the collection. | |
Args: | |
vector (list): The vector to search with. | |
limit (int): The maximum number of documents to return. | |
filter (dict, optional): Criteria to filter documents. | |
fields (list, optional): Specifies the fields to return. | |
include_similarity (bool, optional): Whether to include similarity score in the result. | |
timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. | |
Note that a 'read' timeout event will not block the action taken | |
by the API server if it has received the request already. | |
Returns: | |
list: A list of documents matching the vector search criteria. | |
""" | |
# Must pass a limit | |
if not limit: | |
raise ValueError("Must pass a limit") | |
# Pre-process the included arguments | |
sort, projection = self._recast_as_sort_projection( | |
convert_vector_to_floats(vector), | |
fields=fields, | |
) | |
# Call the underlying find() method to search | |
raw_find_result = self.find( | |
filter=filter, | |
projection=projection, | |
sort=sort, | |
options={ | |
"limit": limit, | |
"includeSimilarity": include_similarity, | |
}, | |
timeout_info=timeout_info, | |
) | |
return cast(List[API_DOC], raw_find_result["data"]["documents"]) | |
def paginate( | |
*, | |
request_method: PaginableRequestMethod, | |
options: Optional[Dict[str, Any]], | |
prefetched: Optional[int] = None, | |
) -> Generator[API_DOC, None, None]: | |
""" | |
Generate paginated results for a given database query method. | |
Args: | |
request_method (function): The database query method to paginate. | |
options (dict, optional): Options for the database query. | |
prefetched (int, optional): Number of pre-fetched documents. | |
Yields: | |
dict: The next document in the paginated result set. | |
""" | |
_options = options or {} | |
response0 = request_method(options=_options) | |
next_page_state = response0["data"]["nextPageState"] | |
options0 = _options | |
if next_page_state is not None and prefetched: | |
def queued_paginate( | |
queue: Queue[Optional[API_DOC]], | |
request_method: PaginableRequestMethod, | |
options: Optional[Dict[str, Any]], | |
) -> None: | |
try: | |
for row in AstraDBCollection.paginate( | |
request_method=request_method, options=options | |
): | |
queue.put(row) | |
finally: | |
queue.put(None) | |
queue: Queue[Optional[API_DOC]] = Queue(prefetched) | |
options1 = {**options0, **{"pageState": next_page_state}} | |
t = threading.Thread( | |
target=queued_paginate, args=(queue, request_method, options1) | |
) | |
t.start() | |
for document in response0["data"]["documents"]: | |
yield document | |
doc = queue.get() | |
while doc is not None: | |
yield doc | |
doc = queue.get() | |
t.join() | |
else: | |
for document in response0["data"]["documents"]: | |
yield document | |
while next_page_state is not None and not prefetched: | |
options1 = {**options0, **{"pageState": next_page_state}} | |
response1 = request_method(options=options1) | |
for document in response1["data"]["documents"]: | |
yield document | |
next_page_state = response1["data"]["nextPageState"] | |
def paginated_find( | |
self, | |
filter: Optional[Dict[str, Any]] = None, | |
projection: Optional[Dict[str, Any]] = None, | |
sort: Optional[Dict[str, Any]] = None, | |
options: Optional[Dict[str, Any]] = None, | |
prefetched: Optional[int] = None, | |
timeout_info: TimeoutInfoWideType = None, | |
) -> Iterator[API_DOC]: | |
""" | |
Perform a paginated search in the collection. | |
Args: | |
filter (dict, optional): Criteria to filter documents. | |
projection (dict, optional): Specifies the fields to return. | |
sort (dict, optional): Specifies the order in which to return matching documents. | |
options (dict, optional): Additional options for the query. | |
prefetched (int, optional): Number of pre-fetched documents. | |
timeout_info: a float, or a TimeoutInfo dict, for each | |
single HTTP request. | |
This is a paginated method, that issues several requests as it | |
needs more data. This parameter controls a single request. | |
Note that a 'read' timeout event will not block the action taken | |
by the API server if it has received the request already. | |
Returns: | |
generator: A generator yielding documents in the paginated result set. | |
""" | |
partialed_find = partial( | |
self.find, | |
filter=filter, | |
projection=projection, | |
sort=sort, | |
timeout_info=timeout_info, | |
) | |
return self.paginate( | |
request_method=partialed_find, | |
options=options, | |
prefetched=prefetched, | |
) | |
def pop(
    self,
    filter: Dict[str, Any],
    pop: Dict[str, Any],
    options: Dict[str, Any],
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Apply a "$pop" update to a document via a "findOneAndUpdate" command.

    Args:
        filter (dict): Criteria to identify the document to update.
        pop (dict): The content of the "$pop" update clause.
        options (dict): Additional options for the update operation.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: the API response to the command.
    """
    pop_update = {"$pop": pop}
    command = make_payload(
        top_level="findOneAndUpdate",
        filter=filter,
        update=pop_update,
        options=options,
    )
    return self._request(
        method=http_methods.POST,
        path=self.base_path,
        json_data=command,
        timeout_info=timeout_info,
    )
def push(
    self,
    filter: Dict[str, Any],
    push: Dict[str, Any],
    options: Dict[str, Any],
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Apply a "$push" update to a document via a "findOneAndUpdate" command.

    Args:
        filter (dict): Criteria to identify the document to update.
        push (dict): The content of the "$push" update clause.
        options (dict): Additional options for the update operation.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: the API response to the command.
    """
    push_update = {"$push": push}
    command = make_payload(
        top_level="findOneAndUpdate",
        filter=filter,
        update=push_update,
        options=options,
    )
    return self._request(
        method=http_methods.POST,
        path=self.base_path,
        json_data=command,
        timeout_info=timeout_info,
    )
def find_one_and_replace(
    self,
    replacement: Dict[str, Any],
    *,
    filter: Optional[Dict[str, Any]] = None,
    projection: Optional[Dict[str, Any]] = None,
    sort: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Run a "findOneAndReplace" command on the collection.

    Args:
        replacement (dict): The new document to replace the existing one.
        filter (dict, optional): Criteria to filter documents.
        projection (dict, optional): Specifies the fields to return.
        sort (dict, optional): Specifies the order in which to find the document.
        options (dict, optional): Additional options for the operation.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The result of the find and replace operation.
    """
    command = make_payload(
        top_level="findOneAndReplace",
        filter=filter,
        projection=projection,
        replacement=replacement,
        options=options,
        sort=sort,
    )
    return self._request(
        method=http_methods.POST,
        path=self.base_path,
        json_data=command,
        timeout_info=timeout_info,
    )
def vector_find_one_and_replace(
    self,
    vector: List[float],
    replacement: Dict[str, Any],
    *,
    filter: Optional[Dict[str, Any]] = None,
    fields: Optional[List[str]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> Union[API_DOC, None]:
    """
    Replace the document selected by a vector search.

    Args:
        vector (list): The vector to search with.
        replacement (dict): The new document to replace the existing one.
        filter (dict, optional): Criteria to filter documents.
        fields (list, optional): Specifies the fields to return in the result.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict or None: the `data.document` entry of the response
            (None if nothing was found).
    """
    # Express the vector search as a regular sort + projection
    float_vector = convert_vector_to_floats(vector)
    sort, projection = self._recast_as_sort_projection(
        float_vector,
        fields=fields,
    )
    replace_response = self.find_one_and_replace(
        replacement=replacement,
        filter=filter,
        projection=projection,
        sort=sort,
        timeout_info=timeout_info,
    )
    return cast(Union[API_DOC, None], replace_response["data"]["document"])
def find_one_and_update(
    self,
    update: Dict[str, Any],
    sort: Optional[Dict[str, Any]] = {},
    filter: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
    projection: Optional[Dict[str, Any]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Run a "findOneAndUpdate" command on the collection.

    Args:
        update (dict): The update to apply to the document.
        sort (dict, optional): Specifies the order in which to find the document.
        filter (dict, optional): Criteria to filter documents.
        options (dict, optional): Additional options for the operation.
        projection (dict, optional): Specifies the fields to return.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The result of the find and update operation.
    """
    # NOTE: the `{}` default of `sort` is a single shared object; this
    # method only hands it over to make_payload and never modifies it.
    command = make_payload(
        top_level="findOneAndUpdate",
        filter=filter,
        update=update,
        options=options,
        sort=sort,
        projection=projection,
    )
    return self._request(
        method=http_methods.POST,
        path=self.base_path,
        json_data=command,
        timeout_info=timeout_info,
    )
def vector_find_one_and_update(
    self,
    vector: List[float],
    update: Dict[str, Any],
    *,
    filter: Optional[Dict[str, Any]] = None,
    fields: Optional[List[str]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> Union[API_DOC, None]:
    """
    Update the document selected by a vector search.

    Args:
        vector (list): The vector to search with.
        update (dict): The update to apply to the matched document.
        filter (dict, optional): Criteria to filter documents before
            applying the vector search.
        fields (list, optional): Specifies the fields to return in the
            updated document.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict or None: the `data.document` entry of the response
            (None if nothing was found).
    """
    # Express the vector search as a regular sort + projection
    float_vector = convert_vector_to_floats(vector)
    sort, projection = self._recast_as_sort_projection(
        float_vector,
        fields=fields,
    )
    update_response = self.find_one_and_update(
        update=update,
        filter=filter,
        sort=sort,
        projection=projection,
        timeout_info=timeout_info,
    )
    return cast(Union[API_DOC, None], update_response["data"]["document"])
def find_one_and_delete(
    self,
    sort: Optional[Dict[str, Any]] = {},
    filter: Optional[Dict[str, Any]] = None,
    projection: Optional[Dict[str, Any]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Run a "findOneAndDelete" command on the collection.

    Args:
        sort (dict, optional): Specifies the order in which to find the document.
        filter (dict, optional): Criteria to filter documents.
        projection (dict, optional): Specifies the fields to return.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The result of the find and delete operation.
    """
    # NOTE: the `{}` default of `sort` is a single shared object; this
    # method only hands it over to make_payload and never modifies it.
    command = make_payload(
        top_level="findOneAndDelete",
        filter=filter,
        sort=sort,
        projection=projection,
    )
    return self._request(
        method=http_methods.POST,
        path=self.base_path,
        json_data=command,
        timeout_info=timeout_info,
    )
def count_documents(
    self, filter: Dict[str, Any] = {}, timeout_info: TimeoutInfoWideType = None
) -> API_RESPONSE:
    """
    Run a "countDocuments" command for the documents matching a filter.

    Args:
        filter (dict, defaults to {}): Criteria to filter documents.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: the response, either
            {"status": {"count": <NUMBER> }}
        or
            {"errors": [...]}
    """
    # NOTE: the `{}` default of `filter` is a single shared object; this
    # method only hands it over to make_payload and never modifies it.
    count_command = make_payload(top_level="countDocuments", filter=filter)
    return self._post(document=count_command, timeout_info=timeout_info)
def find_one(
    self,
    filter: Optional[Dict[str, Any]] = {},
    projection: Optional[Dict[str, Any]] = {},
    sort: Optional[Dict[str, Any]] = {},
    options: Optional[Dict[str, Any]] = {},
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Run a "findOne" command on the collection.

    Args:
        filter (dict, optional): Criteria to filter documents.
        projection (dict, optional): Specifies the fields to return.
        sort (dict, optional): Specifies the order in which to return the document.
        options (dict, optional): Additional options for the query.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: the response, either
            {"data": {"document": <DOCUMENT> }}
        or
            {"data": {"document": None}}
        depending on whether a matching document is found or not.
    """
    # NOTE: the `{}` defaults are single shared objects; this method only
    # hands them over to make_payload and never modifies them.
    find_one_command = make_payload(
        top_level="findOne",
        filter=filter,
        projection=projection,
        options=options,
        sort=sort,
    )
    return self._post(document=find_one_command, timeout_info=timeout_info)
def vector_find_one(
    self,
    vector: List[float],
    *,
    filter: Optional[Dict[str, Any]] = None,
    fields: Optional[List[str]] = None,
    include_similarity: bool = True,
    timeout_info: TimeoutInfoWideType = None,
) -> Union[API_DOC, None]:
    """
    Run a vector similarity search returning a single document.

    Args:
        vector (list): The vector to search with.
        filter (dict, optional): Additional criteria to filter documents.
        fields (list, optional): Specifies the fields to return in the result.
        include_similarity (bool, optional): Whether to include the
            similarity score in the result.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict or None: the `data.document` entry of the response
            (None if no matching document is found).
    """
    # Express the vector search as a regular sort + projection
    float_vector = convert_vector_to_floats(vector)
    sort, projection = self._recast_as_sort_projection(
        float_vector,
        fields=fields,
    )
    find_options = {"includeSimilarity": include_similarity}
    find_response = self.find_one(
        filter=filter,
        projection=projection,
        sort=sort,
        options=find_options,
        timeout_info=timeout_info,
    )
    return cast(Union[API_DOC, None], find_response["data"]["document"])
def insert_one(
    self,
    document: API_DOC,
    failures_allowed: bool = False,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Run an "insertOne" command to store a single document.

    Args:
        document (dict): The document to insert.
        failures_allowed (bool): Whether to allow failures in the insert
            operation (passed to the request as `skip_error_check`).
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database after the insert operation.
    """
    insert_command = make_payload(top_level="insertOne", document=document)
    return self._request(
        method=http_methods.POST,
        path=self.base_path,
        json_data=insert_command,
        skip_error_check=failures_allowed,
        timeout_info=timeout_info,
    )
def insert_many(
    self,
    documents: List[API_DOC],
    options: Optional[Dict[str, Any]] = None,
    partial_failures_allowed: bool = False,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Run an "insertMany" command to store several documents in one request.

    Args:
        documents (list): A list of documents to insert.
        options (dict, optional): Additional options for the insert operation.
        partial_failures_allowed (bool, optional): Whether to allow partial
            failures through the insertion (i.e. on some documents);
            passed to the request as `skip_error_check`.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database after the insert operation.
    """
    insert_command = make_payload(
        top_level="insertMany", documents=documents, options=options
    )
    return self._request(
        method=http_methods.POST,
        path=self.base_path,
        json_data=insert_command,
        skip_error_check=partial_failures_allowed,
        timeout_info=timeout_info,
    )
def chunked_insert_many(
    self,
    documents: List[API_DOC],
    options: Optional[Dict[str, Any]] = None,
    partial_failures_allowed: bool = False,
    chunk_size: int = MAX_INSERT_NUM_DOCUMENTS,
    concurrency: int = 1,
    timeout_info: TimeoutInfoWideType = None,
) -> List[Union[API_RESPONSE, Exception]]:
    """
    Insert a list of documents by splitting it into chunks, each sent with
    its own `insert_many` call, optionally running several calls at once.

    Args:
        documents (list): A list of documents to insert.
        options (dict, optional): Additional options for the insert operation.
        partial_failures_allowed (bool, optional): Whether to allow partial
            failures in the chunk. Should be used combined with
            options={"ordered": False} in most cases. When True, an
            APIRequestError raised for a chunk is stored in the result
            list instead of being raised.
        chunk_size (int, optional): Override the default insertion chunk size.
        concurrency (int, optional): The number of concurrent chunk insertions.
            Default is no concurrency (no thread pool is used).
        timeout_info: a float, or a TimeoutInfo dict, for each single HTTP request.
            This method runs a number of HTTP requests as it works on chunked
            data. The timeout refers to each individual such request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        list: one entry per chunk, in chunk order: the API response, or
            the APIRequestError for that chunk when partial failures are
            allowed. The caller will need to inspect them all, e.g. to
            collate the inserted IDs.

    Raises:
        APIRequestError: if a chunk fails and `partial_failures_allowed`
            is False.
    """

    def iter_chunks() -> Iterator[List[API_DOC]]:
        # Lazily slice the input into consecutive chunks
        for start in range(0, len(documents), chunk_size):
            yield documents[start : start + chunk_size]

    outcomes: List[Union[API_RESPONSE, Exception]] = []

    if options and options.get("ordered") is True and concurrency > 1:
        logger.warning(
            "Using ordered insert with concurrency may lead to unexpected results."
        )

    # Sequential case: plain loop, no thread pool involved
    if concurrency == 1:
        for chunk in iter_chunks():
            try:
                outcome = self.insert_many(
                    chunk,
                    options,
                    partial_failures_allowed,
                    timeout_info=timeout_info,
                )
            except APIRequestError as error:
                if not partial_failures_allowed:
                    raise
                outcomes.append(error)
            else:
                outcomes.append(outcome)
        return outcomes

    # Concurrent case: submit all chunks, then collect in submission order
    with ThreadPoolExecutor(max_workers=concurrency) as pool:
        pending = [
            pool.submit(
                self.insert_many,
                chunk,
                options,
                partial_failures_allowed,
                timeout_info=timeout_info,
            )
            for chunk in iter_chunks()
        ]
        for job in pending:
            try:
                outcome = job.result()
            except APIRequestError as error:
                if not partial_failures_allowed:
                    raise
                outcomes.append(error)
            else:
                outcomes.append(outcome)
    return outcomes
def update_one(
    self,
    filter: Dict[str, Any],
    update: Dict[str, Any],
    sort: Optional[Dict[str, Any]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue an "updateOne" command against this collection.

    Args:
        filter (dict): Criteria selecting the document to update.
        update (dict): The update prescription to apply to the document.
        sort (dict, optional): Sort clause handed to the payload builder
            together with the filter and the update.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database after the update operation.
    """
    target_path = f"{self.base_path}"
    command_payload = make_payload(
        top_level="updateOne",
        filter=filter,
        update=update,
        sort=sort,
    )
    return self._request(
        method=http_methods.POST,
        path=target_path,
        json_data=command_payload,
        timeout_info=timeout_info,
    )
def update_many(
    self,
    filter: Dict[str, Any],
    update: Dict[str, Any],
    options: Optional[Dict[str, Any]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue an "updateMany" command against this collection.

    Args:
        filter (dict): Criteria selecting the documents to update.
        update (dict): The update prescription to apply to the documents.
        options (dict, optional): Additional options handed to the payload
            builder together with the filter and the update.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database after the update operation.
    """
    target_path = f"{self.base_path}"
    command_payload = make_payload(
        top_level="updateMany",
        filter=filter,
        update=update,
        options=options,
    )
    return self._request(
        method=http_methods.POST,
        path=target_path,
        json_data=command_payload,
        timeout_info=timeout_info,
    )
def replace(
    self,
    path: str,
    document: API_DOC,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Replace a document in the collection through a PUT-style call.

    Args:
        path (str): The path to the document to replace.
        document (dict): The new document taking the place of the old one.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database after the replace operation.
    """
    # All the work is delegated to the low-level `_put` helper.
    put_response = self._put(
        path=path,
        document=document,
        timeout_info=timeout_info,
    )
    return put_response
def delete(self, id: str, timeout_info: TimeoutInfoWideType = None) -> API_RESPONSE:
    """
    Delete a single document by its ID (alias forwarding to `delete_one`).

    Args:
        id (str): The ID of the document to delete.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.

    Returns:
        dict: Whatever `delete_one` returns for the same arguments.
    """
    deletion_response = self.delete_one(id, timeout_info=timeout_info)
    return deletion_response
def delete_one(
    self,
    id: str,
    sort: Optional[Dict[str, Any]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue a "deleteOne" command filtering on the document's `_id`.

    Args:
        id (str): The ID of the document to delete.
        sort (dict, optional): Sort clause handed to the payload builder
            along with the `_id` filter.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database after the delete operation.
    """
    target_path = f"{self.base_path}"
    id_filter = {"_id": id}
    command_payload = make_payload(
        top_level="deleteOne",
        filter=id_filter,
        sort=sort,
    )
    return self._request(
        method=http_methods.POST,
        path=target_path,
        json_data=command_payload,
        timeout_info=timeout_info,
    )
def delete_one_by_predicate(
    self,
    filter: Dict[str, Any],
    sort: Optional[Dict[str, Any]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue a "deleteOne" command using an arbitrary filter clause.

    Args:
        filter (dict): Any filter dictionary selecting the document.
        sort (dict, optional): Sort clause handed to the payload builder
            along with the filter.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database after the delete operation.
    """
    target_path = f"{self.base_path}"
    command_payload = make_payload(
        top_level="deleteOne",
        filter=filter,
        sort=sort,
    )
    return self._request(
        method=http_methods.POST,
        path=target_path,
        json_data=command_payload,
        timeout_info=timeout_info,
    )
def delete_many(
    self,
    filter: Dict[str, Any],
    skip_error_check: bool = False,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue a single "deleteMany" command with the given filter.

    Args:
        filter (dict): Criteria to identify the documents to delete.
        skip_error_check (bool): whether to ignore the check for API error
            and return the response untouched. Default is False.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database after the delete operation.
    """
    target_path = f"{self.base_path}"
    # The payload is assembled by hand here: just the command and its filter.
    command_payload = {"deleteMany": {"filter": filter}}
    return self._request(
        method=http_methods.POST,
        path=target_path,
        json_data=command_payload,
        skip_error_check=skip_error_check,
        timeout_info=timeout_info,
    )
def chunked_delete_many(
    self, filter: Dict[str, Any], timeout_info: TimeoutInfoWideType = None
) -> List[API_RESPONSE]:
    """
    Repeat `delete_many` with the same filter until the API stops
    reporting that more data is left to delete.

    Args:
        filter (dict): Criteria to identify the documents to delete.
        timeout_info: a float, or a TimeoutInfo dict, for each single HTTP
            request. Several requests may be issued, one after the other:
            the timeout applies to each of them individually.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        List[dict]: The responses from all the calls, in call order.
    """
    collected = []
    while True:
        step_response = self.delete_many(filter=filter, timeout_info=timeout_info)
        collected.append(step_response)
        # Keep going only while the response flags status.moreData as truthy.
        if not step_response.get("status", {}).get("moreData", False):
            break
    return collected
def clear(self, timeout_info: TimeoutInfoWideType = None) -> API_RESPONSE:
    """
    Clear the collection by running `delete_many` with an empty filter.

    Args:
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database.

    Raises:
        ValueError: if the response's status.deletedCount is not -1.
    """
    response = self.delete_many(filter={}, timeout_info=timeout_info)
    reported_count = response.get("status", {}).get("deletedCount")
    # A deletedCount of -1 is the only outcome accepted as a successful clear.
    if reported_count == -1:
        return response
    raise ValueError(
        f"Could not issue a clear-collection API command (response: {json.dumps(response)})."
    )
def delete_subdocument(
    self, id: str, subdoc: str, timeout_info: TimeoutInfoWideType = None
) -> API_RESPONSE:
    """
    Remove a subdocument or field from a document, by means of a
    "findOneAndUpdate" command carrying an `$unset` update.

    Args:
        id (str): The ID of the document containing the subdocument.
        subdoc (str): The key of the subdocument or field to remove.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database after the update operation.
    """
    target_path = f"{self.base_path}"
    unset_clause = {"$unset": {subdoc: ""}}
    command_payload = {
        "findOneAndUpdate": {
            "filter": {"_id": id},
            "update": unset_clause,
        }
    }
    return self._request(
        method=http_methods.POST,
        path=target_path,
        json_data=command_payload,
        timeout_info=timeout_info,
    )
def upsert(self, document: API_DOC, timeout_info: TimeoutInfoWideType = None) -> str:
    """
    Upsert a single document (alias forwarding to `upsert_one`).

    Args:
        document (dict): The document to insert or update.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP requests.

    Returns:
        str: Whatever `upsert_one` returns for the same arguments.
    """
    upserted_id = self.upsert_one(document, timeout_info=timeout_info)
    return upserted_id
def upsert_one(
    self, document: API_DOC, timeout_info: TimeoutInfoWideType = None
) -> str:
    """
    Emulate an upsert operation for a single document in the collection.

    An insertion is attempted first (with failures allowed). If the first
    reported error says the document already exists, the document is
    replaced through `find_one_and_replace`, filtering on its `_id`.

    Args:
        document (dict): The document to insert or update.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP requests.
            This method may issue one or two requests, depending on what
            is detected on DB. This timeout controls each HTTP request individually.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        str: The _id of the inserted or updated document.

    Raises:
        ValueError: if the insertion reports any other error, or if it
            reports no errors but also no inserted IDs.
    """
    insert_outcome = self.insert_one(
        document, failures_allowed=True, timeout_info=timeout_info
    )

    # Happy path: the insertion went through and tells us the new ID.
    if "errors" not in insert_outcome:
        inserted_ids = insert_outcome.get("status", {}).get("insertedIds", [])
        if not inserted_ids:
            raise ValueError("Unexplained empty insertedIds from API")
        return cast(str, inserted_ids[0])

    # Only the first reported error is inspected.
    first_error = insert_outcome["errors"][0]
    if (
        "errorCode" not in first_error
        or first_error["errorCode"] != "DOCUMENT_ALREADY_EXISTS"
    ):
        raise ValueError(insert_outcome)

    # The document is already there: overwrite it instead.
    replace_outcome = self.find_one_and_replace(
        replacement=document,
        filter={"_id": document["_id"]},
        timeout_info=timeout_info,
    )
    return cast(str, replace_outcome["data"]["document"]["_id"])
def upsert_many(
    self,
    documents: list[API_DOC],
    concurrency: int = 1,
    partial_failures_allowed: bool = False,
    timeout_info: TimeoutInfoWideType = None,
) -> List[Union[str, Exception]]:
    """
    Emulate an upsert operation for multiple documents in the collection.
    Each document is upserted on its own through `upsert_one`.

    Args:
        documents (List[dict]): The documents to insert or update.
        concurrency (int, optional): The number of concurrent upserts.
            With the default of 1 the documents are processed one after
            the other, without a thread pool.
        partial_failures_allowed (bool, optional): Whether to allow partial
            failures in the batch. If True, an exception raised for a
            document is stored in the result list at that document's
            position; if False, the first such exception is re-raised.
            This holds regardless of the value of `concurrency`.
        timeout_info: a float, or a TimeoutInfo dict, for each HTTP request.
            This method issues a separate HTTP request for each document to
            insert: the timeout controls each such request individually.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        List[Union[str, Exception]]: one entry per document, in input order:
            the "_id" of the upserted document or, when partial failures
            are allowed, the exception raised while upserting it.

    Raises:
        Exception: whatever a single upsert raised, when
            `partial_failures_allowed` is False.
    """
    results: List[Union[str, Exception]] = []

    # If concurrency is 1, no need for a thread pool
    if concurrency == 1:
        for document in documents:
            try:
                results.append(self.upsert_one(document, timeout_info=timeout_info))
            except Exception as exc:
                # Same policy as in the concurrent branch below.
                if not partial_failures_allowed:
                    raise
                results.append(exc)
        return results

    # Perform the bulk upsert with concurrency
    with ThreadPoolExecutor(max_workers=concurrency) as executor:
        # Submit the jobs (one per document, same worker as the sequential path)
        futures = [
            executor.submit(self.upsert_one, document, timeout_info=timeout_info)
            for document in documents
        ]
        # Collect the results in submission order
        for future in futures:
            try:
                results.append(future.result())
            except Exception as exc:
                if not partial_failures_allowed:
                    raise
                results.append(exc)
    return results
class AsyncAstraDBCollection: | |
def __init__( | |
self, | |
collection_name: str, | |
astra_db: Optional[AsyncAstraDB] = None, | |
token: Optional[str] = None, | |
api_endpoint: Optional[str] = None, | |
namespace: Optional[str] = None, | |
caller_name: Optional[str] = None, | |
caller_version: Optional[str] = None, | |
) -> None: | |
""" | |
Initialize an AstraDBCollection instance. | |
Args: | |
collection_name (str): The name of the collection. | |
astra_db (AstraDB, optional): An instance of Astra DB. | |
token (str, optional): Authentication token for Astra DB. | |
api_endpoint (str, optional): API endpoint URL. | |
namespace (str, optional): Namespace for the database. | |
caller_name (str, optional): identity of the caller ("my_framework") | |
If passing a client, its caller is used as fallback | |
caller_version (str, optional): version of the caller code ("1.0.3") | |
If passing a client, its caller is used as fallback | |
""" | |
# Check for presence of the Astra DB object | |
if astra_db is None: | |
if token is None or api_endpoint is None: | |
raise AssertionError("Must provide token and api_endpoint") | |
astra_db = AsyncAstraDB( | |
token=token, | |
api_endpoint=api_endpoint, | |
namespace=namespace, | |
caller_name=caller_name, | |
caller_version=caller_version, | |
) | |
else: | |
# if astra_db passed, copy and apply possible overrides | |
astra_db = astra_db.copy( | |
token=token, | |
api_endpoint=api_endpoint, | |
namespace=namespace, | |
caller_name=caller_name, | |
caller_version=caller_version, | |
) | |
# Set the remaining instance attributes | |
self.astra_db: AsyncAstraDB = astra_db | |
self.caller_name: Optional[str] = self.astra_db.caller_name | |
self.caller_version: Optional[str] = self.astra_db.caller_version | |
self.client = astra_db.client | |
self.collection_name = collection_name | |
self.base_path: str = f"{self.astra_db.base_path}/{self.collection_name}" | |
def __repr__(self) -> str: | |
return f'AsyncAstraDBCollection[astra_db="{self.astra_db}", collection_name="{self.collection_name}"]' | |
def __eq__(self, other: Any) -> bool: | |
if isinstance(other, AsyncAstraDBCollection): | |
return all( | |
[ | |
self.collection_name == other.collection_name, | |
self.astra_db == other.astra_db, | |
self.caller_name == other.caller_name, | |
self.caller_version == other.caller_version, | |
] | |
) | |
else: | |
return False | |
def copy( | |
self, | |
*, | |
collection_name: Optional[str] = None, | |
token: Optional[str] = None, | |
api_endpoint: Optional[str] = None, | |
api_path: Optional[str] = None, | |
api_version: Optional[str] = None, | |
namespace: Optional[str] = None, | |
caller_name: Optional[str] = None, | |
caller_version: Optional[str] = None, | |
) -> AsyncAstraDBCollection: | |
return AsyncAstraDBCollection( | |
collection_name=collection_name or self.collection_name, | |
astra_db=self.astra_db.copy( | |
token=token, | |
api_endpoint=api_endpoint, | |
api_path=api_path, | |
api_version=api_version, | |
namespace=namespace, | |
caller_name=caller_name, | |
caller_version=caller_version, | |
), | |
caller_name=caller_name or self.caller_name, | |
caller_version=caller_version or self.caller_version, | |
) | |
def set_caller(
    self,
    caller_name: Optional[str] = None,
    caller_version: Optional[str] = None,
) -> None:
    """
    Set the caller identity on the underlying client first, then on
    this collection object.

    Args:
        caller_name (str, optional): identity of the caller.
        caller_version (str, optional): version of the caller code.
    """
    self.astra_db.set_caller(caller_name=caller_name, caller_version=caller_version)
    self.caller_name, self.caller_version = caller_name, caller_version
def to_sync(self) -> AstraDBCollection:
    """
    Build the synchronous counterpart of this collection: same collection
    name and caller identity, on top of the client's own `to_sync()`.

    Returns:
        AstraDBCollection: the newly-created synchronous collection.
    """
    sync_db = self.astra_db.to_sync()
    return AstraDBCollection(
        collection_name=self.collection_name,
        astra_db=sync_db,
        caller_name=self.caller_name,
        caller_version=self.caller_version,
    )
async def _request(
    self,
    method: str = http_methods.POST,
    path: Optional[str] = None,
    json_data: Optional[Dict[str, Any]] = None,
    url_params: Optional[Dict[str, Any]] = None,
    skip_error_check: bool = False,
    timeout_info: TimeoutInfoWideType = None,
    **kwargs: Any,
) -> API_RESPONSE:
    """
    Run one HTTP request through `async_api_request`, using this
    collection's client, base URL, token and caller identity.

    The JSON body goes through `normalize_for_api` before being sent and
    the response goes through `restore_from_api` before being returned.
    Extra keyword arguments are accepted but not used.

    Args:
        method (str): the HTTP method (defaults to POST).
        path (str, optional): the path for the request.
        json_data (dict, optional): the JSON body of the request.
        url_params (dict, optional): URL parameters for the request.
        skip_error_check (bool): passed as-is to `async_api_request`.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.

    Returns:
        dict: the response, after `restore_from_api` has been applied.
    """
    outgoing_body = normalize_for_api(json_data)
    request_timeout = to_httpx_timeout(timeout_info)
    raw_response = await async_api_request(
        client=self.client,
        base_url=self.astra_db.base_url,
        auth_header=DEFAULT_AUTH_HEADER,
        token=self.astra_db.token,
        method=method,
        json_data=outgoing_body,
        url_params=url_params,
        path=path,
        skip_error_check=skip_error_check,
        caller_name=self.caller_name,
        caller_version=self.caller_version,
        timeout=request_timeout,
    )
    return restore_from_api(raw_response)
async def post_raw_request(
    self, body: Dict[str, Any], timeout_info: TimeoutInfoWideType = None
) -> API_RESPONSE:
    """
    POST an arbitrary, caller-supplied body to this collection's base path.

    Args:
        body (dict): the JSON body to send, used as it is.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.

    Returns:
        dict: the response from the API.
    """
    raw_response = await self._request(
        method=http_methods.POST,
        path=self.base_path,
        json_data=body,
        timeout_info=timeout_info,
    )
    return raw_response
async def _get(
    self,
    path: Optional[str] = None,
    options: Optional[Dict[str, Any]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> Optional[API_RESPONSE]:
    """
    Run a GET request on the collection's base path, or on a sub-path of it.

    Args:
        path (str, optional): sub-path appended to the base path, if truthy.
        options (dict, optional): passed as URL parameters of the request.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.

    Returns:
        dict or None: the response if it is a dictionary, otherwise None.
    """
    target_path = self.base_path if not path else f"{self.base_path}/{path}"
    fetched = await self._request(
        method=http_methods.GET,
        path=target_path,
        url_params=options,
        timeout_info=timeout_info,
    )
    return fetched if isinstance(fetched, dict) else None
async def _put(
    self,
    path: Optional[str] = None,
    document: Optional[API_RESPONSE] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Run a PUT request on the collection's base path, or on a sub-path of it.

    Args:
        path (str, optional): sub-path appended to the base path, if truthy.
        document (dict, optional): the JSON body of the request.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.

    Returns:
        dict: the response from the API.
    """
    target_path = self.base_path if not path else f"{self.base_path}/{path}"
    return await self._request(
        method=http_methods.PUT,
        path=target_path,
        json_data=document,
        timeout_info=timeout_info,
    )
async def _post(
    self,
    path: Optional[str] = None,
    document: Optional[API_DOC] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Run a POST request on the collection's base path, or on a sub-path of it.

    Args:
        path (str, optional): sub-path appended to the base path, if truthy.
        document (dict, optional): the JSON body of the request.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.

    Returns:
        dict: the response from the API.
    """
    target_path = self.base_path if not path else f"{self.base_path}/{path}"
    return await self._request(
        method=http_methods.POST,
        path=target_path,
        json_data=document,
        timeout_info=timeout_info,
    )
def _recast_as_sort_projection( | |
self, vector: List[float], fields: Optional[List[str]] = None | |
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]: | |
""" | |
Given a vector and optionally a list of fields, | |
reformulate them as a sort, projection pair for regular | |
'find'-like API calls (with basic validation as well). | |
""" | |
# Must pass a vector | |
if not vector: | |
raise ValueError("Must pass a vector") | |
# Edge case for field selection | |
if fields and "$similarity" in fields: | |
raise ValueError("Please use the `include_similarity` parameter") | |
# Build the new vector parameter | |
sort: Dict[str, Any] = {"$vector": vector} | |
# Build the new fields parameter | |
# Note: do not leave projection={}, make it None | |
# (or it will devour $similarity away in the API response) | |
if fields is not None and len(fields) > 0: | |
projection = {f: 1 for f in fields} | |
else: | |
projection = None | |
return sort, projection | |
async def get(
    self, path: Optional[str] = None, timeout_info: TimeoutInfoWideType = None
) -> Optional[API_RESPONSE]:
    """
    Retrieve a document from the collection by its path,
    delegating to the low-level `_get` helper.

    Args:
        path (str, optional): The path of the document to retrieve.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The retrieved document.
    """
    retrieved = await self._get(path=path, timeout_info=timeout_info)
    return retrieved
async def find(
    self,
    filter: Optional[Dict[str, Any]] = None,
    projection: Optional[Dict[str, Any]] = None,
    sort: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue a "find" command and return the API response as it is.

    Args:
        filter (dict, optional): Criteria to filter documents.
        projection (dict, optional): Specifies the fields to return.
        sort (dict, optional): Specifies the order in which to return matching documents.
        options (dict, optional): Additional options for the query.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The query response containing matched documents.
    """
    find_payload = make_payload(
        top_level="find",
        filter=filter,
        projection=projection,
        options=options,
        sort=sort,
    )
    return await self._post(document=find_payload, timeout_info=timeout_info)
async def vector_find( | |
self, | |
vector: List[float], | |
*, | |
limit: int, | |
filter: Optional[Dict[str, Any]] = None, | |
fields: Optional[List[str]] = None, | |
include_similarity: bool = True, | |
timeout_info: TimeoutInfoWideType = None, | |
) -> List[API_DOC]: | |
""" | |
Perform a vector-based search in the collection. | |
Args: | |
vector (list): The vector to search with. | |
limit (int): The maximum number of documents to return. | |
filter (dict, optional): Criteria to filter documents. | |
fields (list, optional): Specifies the fields to return. | |
include_similarity (bool, optional): Whether to include similarity score in the result. | |
timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. | |
Note that a 'read' timeout event will not block the action taken | |
by the API server if it has received the request already. | |
Returns: | |
list: A list of documents matching the vector search criteria. | |
""" | |
# Must pass a limit | |
if not limit: | |
raise ValueError("Must pass a limit") | |
# Pre-process the included arguments | |
sort, projection = self._recast_as_sort_projection( | |
vector, | |
fields=fields, | |
) | |
# Call the underlying find() method to search | |
raw_find_result = await self.find( | |
filter=filter, | |
projection=projection, | |
sort=sort, | |
options={ | |
"limit": limit, | |
"includeSimilarity": include_similarity, | |
}, | |
timeout_info=timeout_info, | |
) | |
return cast(List[API_DOC], raw_find_result["data"]["documents"]) | |
@staticmethod
async def paginate(
    *,
    request_method: AsyncPaginableRequestMethod,
    options: Optional[Dict[str, Any]],
    prefetched: Optional[int] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> AsyncGenerator[API_DOC, None]:
    """
    Generate paginated results for a given database query method.

    This is a static method (it takes no `self`), so that it can be
    invoked both on the class and on an instance.

    Args:
        request_method (function): The database query method to paginate.
            It is awaited with a single `options` keyword argument.
        options (dict, optional): Options for the database query.
        prefetched (int, optional): Number of pre-fetched documents. If
            truthy (and there is more than one page), the following pages
            are fetched by a background task feeding a queue of this size.
        timeout_info: accepted for signature compatibility but not used
            here: any timeout must already be bound to `request_method`.

    Yields:
        dict: The next document in the paginated result set.

    Raises:
        Exception: whatever `request_method` raises, also when the failure
            happens in the background prefetching task.
    """
    base_options = options or {}
    first_response = await request_method(options=base_options)
    next_page_state = first_response["data"]["nextPageState"]

    if next_page_state is None or not prefetched:
        # Sequential mode: a page is requested once the previous is exhausted.
        for document in first_response["data"]["documents"]:
            yield document
        while next_page_state is not None:
            page_response = await request_method(
                options={**base_options, "pageState": next_page_state}
            )
            for document in page_response["data"]["documents"]:
                yield document
            next_page_state = page_response["data"]["nextPageState"]
        return

    # Prefetching mode: a background task fills a bounded queue;
    # None is the end-of-data marker.
    queue: asyncio.Queue[Optional[API_DOC]] = asyncio.Queue(prefetched)

    async def prefetch_pages(page_state: Optional[str]) -> None:
        try:
            while page_state is not None:
                page_response = await request_method(
                    options={**base_options, "pageState": page_state}
                )
                for prefetched_doc in page_response["data"]["documents"]:
                    await queue.put(prefetched_doc)
                page_state = page_response["data"]["nextPageState"]
        except asyncio.CancelledError:
            # The consumer is gone: nobody is waiting for the end marker,
            # and putting it on a full queue would block forever.
            raise
        except Exception:
            # Unblock the consumer, which then re-raises by awaiting the task.
            await queue.put(None)
            raise
        else:
            await queue.put(None)

    # Keep a reference to the task: the event loop only holds a weak one.
    producer = asyncio.create_task(prefetch_pages(next_page_state))
    try:
        for document in first_response["data"]["documents"]:
            yield document
        while True:
            queued_doc = await queue.get()
            if queued_doc is None:
                break
            yield queued_doc
        # Surface a failure of the background fetches instead of silently
        # returning a truncated result set.
        await producer
    finally:
        # If the consumer stops early, do not leave the fetches running.
        if not producer.done():
            producer.cancel()
def paginated_find(
    self,
    filter: Optional[Dict[str, Any]] = None,
    projection: Optional[Dict[str, Any]] = None,
    sort: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
    prefetched: Optional[int] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> AsyncIterator[API_DOC]:
    """
    Perform a paginated search in the collection.

    The search parameters and the timeout are bound to `find`, and the
    resulting callable is handed to `paginate` with options and prefetch.

    Args:
        filter (dict, optional): Criteria to filter documents.
        projection (dict, optional): Specifies the fields to return.
        sort (dict, optional): Specifies the order in which to return matching documents.
        options (dict, optional): Additional options for the query.
        prefetched (int, optional): Number of pre-fetched documents
        timeout_info: a float, or a TimeoutInfo dict, for each
            single HTTP request.
            This is a paginated method, that issues several requests as it
            needs more data. This parameter controls a single request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        generator: A generator yielding documents in the paginated result set.
    """
    bound_arguments = {
        "filter": filter,
        "projection": projection,
        "sort": sort,
        "timeout_info": timeout_info,
    }
    # Everything but `options` is frozen: paginate supplies that per page.
    page_fetcher = partial(self.find, **bound_arguments)
    return self.paginate(
        request_method=page_fetcher,
        options=options,
        prefetched=prefetched,
    )
async def pop(
    self,
    filter: Dict[str, Any],
    pop: Dict[str, Any],
    options: Dict[str, Any],
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Apply a `$pop` update to a document through "findOneAndUpdate".

    Args:
        filter (dict): Criteria to identify the document to update.
        pop (dict): The content of the `$pop` update clause.
        options (dict): Additional options for the update operation.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the API for the update command.
    """
    pop_clause = {"$pop": pop}
    command_payload = make_payload(
        top_level="findOneAndUpdate",
        filter=filter,
        update=pop_clause,
        options=options,
    )
    return await self._request(
        method=http_methods.POST,
        path=self.base_path,
        json_data=command_payload,
        timeout_info=timeout_info,
    )
async def push(
    self,
    filter: Dict[str, Any],
    push: Dict[str, Any],
    options: Dict[str, Any],
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Apply a `$push` update to a document through "findOneAndUpdate".

    Args:
        filter (dict): Criteria to identify the document to update.
        push (dict): The content of the `$push` update clause.
        options (dict): Additional options for the update operation.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The result of the update operation.
    """
    push_clause = {"$push": push}
    command_payload = make_payload(
        top_level="findOneAndUpdate",
        filter=filter,
        update=push_clause,
        options=options,
    )
    return await self._request(
        method=http_methods.POST,
        path=self.base_path,
        json_data=command_payload,
        timeout_info=timeout_info,
    )
async def find_one_and_replace(
    self,
    replacement: Dict[str, Any],
    *,
    filter: Optional[Dict[str, Any]] = None,
    projection: Optional[Dict[str, Any]] = None,
    sort: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue a "findOneAndReplace" command against this collection.

    Args:
        replacement (dict): The new document to replace the existing one.
        filter (dict, optional): Criteria to filter documents.
        projection (dict, optional): Specifies the fields to return.
        sort (dict, optional): Specifies the order in which to find the document.
        options (dict, optional): Additional options for the operation.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The result of the find and replace operation.
    """
    target_path = f"{self.base_path}"
    command_payload = make_payload(
        top_level="findOneAndReplace",
        filter=filter,
        projection=projection,
        replacement=replacement,
        options=options,
        sort=sort,
    )
    return await self._request(
        method=http_methods.POST,
        path=target_path,
        json_data=command_payload,
        timeout_info=timeout_info,
    )
async def vector_find_one_and_replace( | |
self, | |
vector: List[float], | |
replacement: Dict[str, Any], | |
*, | |
filter: Optional[Dict[str, Any]] = None, | |
fields: Optional[List[str]] = None, | |
timeout_info: TimeoutInfoWideType = None, | |
) -> Union[API_DOC, None]: | |
""" | |
Perform a vector-based search and replace the first matched document. | |
Args: | |
vector (dict): The vector to search with. | |
replacement (dict): The new document to replace the existing one. | |
filter (dict, optional): Criteria to filter documents. | |
fields (list, optional): Specifies the fields to return in the result. | |
timeout_info: a float, or a TimeoutInfo dict, for the HTTP request. | |
Note that a 'read' timeout event will not block the action taken | |
by the API server if it has received the request already. | |
Returns: | |
dict or None: either the matched document or None if nothing found | |
""" | |
# Pre-process the included arguments | |
sort, projection = self._recast_as_sort_projection( | |
vector, | |
fields=fields, | |
) | |
# Call the underlying find() method to search | |
raw_find_result = await self.find_one_and_replace( | |
replacement=replacement, | |
filter=filter, | |
projection=projection, | |
sort=sort, | |
timeout_info=timeout_info, | |
) | |
return cast(Union[API_DOC, None], raw_find_result["data"]["document"]) | |
async def find_one_and_update(
    self,
    update: Dict[str, Any],
    sort: Optional[Dict[str, Any]] = {},
    filter: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
    projection: Optional[Dict[str, Any]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue a "findOneAndUpdate" command against this collection.

    Args:
        update (dict): The update to apply to the document.
        sort (dict, optional): Specifies the order in which to find the
            document. Defaults to an empty dict, which this method
            passes on without ever modifying it.
        filter (dict, optional): Criteria to filter documents.
        options (dict, optional): Additional options for the operation.
        projection (dict, optional): Specifies the fields to return.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The result of the find and update operation.
    """
    target_path = f"{self.base_path}"
    command_payload = make_payload(
        top_level="findOneAndUpdate",
        filter=filter,
        update=update,
        options=options,
        sort=sort,
        projection=projection,
    )
    return await self._request(
        method=http_methods.POST,
        path=target_path,
        json_data=command_payload,
        timeout_info=timeout_info,
    )
async def vector_find_one_and_update(
    self,
    vector: List[float],
    update: Dict[str, Any],
    *,
    filter: Optional[Dict[str, Any]] = None,
    fields: Optional[List[str]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> Union[API_DOC, None]:
    """
    Run a vector search and apply an update to the first document it matches.

    Args:
        vector (list): The vector to search with.
        update (dict): The update to apply to the matched document.
        filter (dict, optional): Criteria to filter documents before
            applying the vector search.
        fields (list, optional): Specifies the fields to return in the
            updated document.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict or None: The result of the vector-based find and
            update operation, or None if nothing found
    """
    # Translate (vector, fields) into the sort/projection clauses
    vector_sort, field_projection = self._recast_as_sort_projection(
        vector,
        fields=fields,
    )
    # Delegate to find_one_and_update and unwrap the document from the response
    api_response = await self.find_one_and_update(
        update=update,
        filter=filter,
        sort=vector_sort,
        projection=field_projection,
        timeout_info=timeout_info,
    )
    return cast(Union[API_DOC, None], api_response["data"]["document"])
async def find_one_and_delete(
    self,
    sort: Optional[Dict[str, Any]] = {},
    filter: Optional[Dict[str, Any]] = None,
    projection: Optional[Dict[str, Any]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue a `findOneAndDelete` command: locate one document and delete it.

    Args:
        sort (dict, optional): Specifies the order in which to find the document.
        filter (dict, optional): Criteria to filter documents.
        projection (dict, optional): Specifies the fields to return.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The result of the find and delete operation.
    """
    # NOTE(review): `sort` defaults to a shared `{}` literal; it is only
    # forwarded here, never mutated.
    command = make_payload(
        top_level="findOneAndDelete",
        filter=filter,
        sort=sort,
        projection=projection,
    )
    return await self._request(
        method=http_methods.POST,
        path=f"{self.base_path}",
        json_data=command,
        timeout_info=timeout_info,
    )
async def count_documents(
    self, filter: Dict[str, Any] = {}, timeout_info: TimeoutInfoWideType = None
) -> API_RESPONSE:
    """
    Count the documents that satisfy a filter predicate.

    Args:
        filter (dict, defaults to {}): Criteria to filter documents.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: the response, either
                {"status": {"count": <NUMBER> }}
            or
                {"errors": [...]}
    """
    count_command = make_payload(top_level="countDocuments", filter=filter)
    return await self._post(document=count_command, timeout_info=timeout_info)
async def find_one(
    self,
    filter: Optional[Dict[str, Any]] = {},
    projection: Optional[Dict[str, Any]] = {},
    sort: Optional[Dict[str, Any]] = {},
    options: Optional[Dict[str, Any]] = {},
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue a `findOne` command against the collection.

    Args:
        filter (dict, optional): Criteria to filter documents.
        projection (dict, optional): Specifies the fields to return.
        sort (dict, optional): Specifies the order in which to return the document.
        options (dict, optional): Additional options for the query.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: the response, either
                {"data": {"document": <DOCUMENT> }}
            or
                {"data": {"document": None}}
            depending on whether a matching document is found or not.
    """
    # NOTE(review): the `{}` defaults are shared literals; they are only
    # forwarded to make_payload here, never mutated by this method.
    find_command = make_payload(
        top_level="findOne",
        filter=filter,
        projection=projection,
        options=options,
        sort=sort,
    )
    return await self._post(document=find_command, timeout_info=timeout_info)
async def vector_find_one(
    self,
    vector: List[float],
    *,
    filter: Optional[Dict[str, Any]] = None,
    fields: Optional[List[str]] = None,
    include_similarity: bool = True,
    timeout_info: TimeoutInfoWideType = None,
) -> Union[API_DOC, None]:
    """
    Run a vector search and return the single best-matching document.

    Args:
        vector (list): The vector to search with.
        filter (dict, optional): Additional criteria to filter documents.
        fields (list, optional): Specifies the fields to return in the result.
        include_similarity (bool, optional): Whether to include similarity
            score in the result.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict or None: The found document or None if no matching document is found.
    """
    # Translate (vector, fields) into the sort/projection clauses
    vector_sort, field_projection = self._recast_as_sort_projection(
        vector,
        fields=fields,
    )
    # Delegate to find_one and unwrap the document from the response
    similarity_options = {"includeSimilarity": include_similarity}
    api_response = await self.find_one(
        filter=filter,
        projection=field_projection,
        sort=vector_sort,
        options=similarity_options,
        timeout_info=timeout_info,
    )
    return cast(Union[API_DOC, None], api_response["data"]["document"])
async def insert_one(
    self,
    document: API_DOC,
    failures_allowed: bool = False,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue an `insertOne` command for a single document.

    Args:
        document (dict): The document to insert.
        failures_allowed (bool): Whether to allow failures in the insert
            operation; forwarded to the request as `skip_error_check`.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database after the insert operation.
    """
    insert_command = make_payload(top_level="insertOne", document=document)
    return await self._request(
        method=http_methods.POST,
        path=self.base_path,
        json_data=insert_command,
        skip_error_check=failures_allowed,
        timeout_info=timeout_info,
    )
async def insert_many(
    self,
    documents: List[API_DOC],
    options: Optional[Dict[str, Any]] = None,
    partial_failures_allowed: bool = False,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue a single `insertMany` command for a list of documents.

    Args:
        documents (list): A list of documents to insert.
        options (dict, optional): Additional options for the insert operation.
        partial_failures_allowed (bool, optional): Whether to allow partial
            failures through the insertion (i.e. on some documents);
            forwarded to the request as `skip_error_check`.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database after the insert operation.
    """
    insert_command = make_payload(
        top_level="insertMany", documents=documents, options=options
    )
    return await self._request(
        method=http_methods.POST,
        path=f"{self.base_path}",
        json_data=insert_command,
        skip_error_check=partial_failures_allowed,
        timeout_info=timeout_info,
    )
async def chunked_insert_many(
    self,
    documents: List[API_DOC],
    options: Optional[Dict[str, Any]] = None,
    partial_failures_allowed: bool = False,
    chunk_size: int = MAX_INSERT_NUM_DOCUMENTS,
    concurrency: int = 1,
    timeout_info: TimeoutInfoWideType = None,
) -> List[Union[API_RESPONSE, Exception]]:
    """
    Insert multiple documents into the collection, handling chunking and
    optionally with concurrent insertions.

    Args:
        documents (list): A list of documents to insert.
        options (dict, optional): Additional options for the insert operation.
        partial_failures_allowed (bool, optional): Whether to allow partial
            failures in the chunk. Should be used combined with
            options={"ordered": False} in most cases.
        chunk_size (int, optional): Override the default insertion chunk size.
            Must be at least 1.
        concurrency (int, optional): The number of concurrent chunk insertions.
            Default is no concurrency. Must be at least 1.
        timeout_info: a float, or a TimeoutInfo dict, for each single HTTP request.
            This method runs a number of HTTP requests as it works on chunked
            data. The timeout refers to each individual such request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        list: The responses from the database after the chunked insert operation.
            This is a list of individual responses from the API: the caller
            will need to inspect them all, e.g. to collate the inserted IDs.
            When partial failures are allowed, a chunk that raised an
            APIRequestError is represented by that exception object.

    Raises:
        ValueError: if `chunk_size` or `concurrency` is smaller than 1.
        APIRequestError: re-raised from a chunk insertion unless
            `partial_failures_allowed` is set.
    """
    # A non-positive chunk size would either make range() fail (0) or
    # silently insert nothing (negative): reject both explicitly.
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be at least 1 (got {chunk_size}).")
    # Semaphore(0) can never be acquired, so the method would hang forever.
    if concurrency < 1:
        raise ValueError(f"concurrency must be at least 1 (got {concurrency}).")

    sem = asyncio.Semaphore(concurrency)

    async def concurrent_insert_many(
        docs: List[API_DOC],
        index: int,
        partial_failures_allowed: bool,
    ) -> Union[API_RESPONSE, Exception]:
        # `index` is the zero-based ordinal of the chunk (not a document offset)
        async with sem:
            logger.debug("Processing chunk #%s of size %s", index + 1, len(docs))
            try:
                return await self.insert_many(
                    documents=docs,
                    options=options,
                    partial_failures_allowed=partial_failures_allowed,
                    timeout_info=timeout_info,
                )
            except APIRequestError as e:
                if partial_failures_allowed:
                    return e
                raise

    chunk_starts = range(0, len(documents), chunk_size)

    results: List[Union[API_RESPONSE, Exception]]
    if concurrency > 1:
        tasks = [
            asyncio.create_task(
                concurrent_insert_many(
                    documents[start : start + chunk_size],
                    chunk_no,
                    partial_failures_allowed,
                )
            )
            for chunk_no, start in enumerate(chunk_starts)
        ]
        results = await asyncio.gather(*tasks, return_exceptions=False)
    else:
        # this ensures the expectation of
        # "sequential strictly obeys fail-fast if ordered and concurrency==1"
        results = [
            await concurrent_insert_many(
                documents[start : start + chunk_size],
                chunk_no,
                partial_failures_allowed,
            )
            for chunk_no, start in enumerate(chunk_starts)
        ]
    return results
async def update_one(
    self,
    filter: Dict[str, Any],
    update: Dict[str, Any],
    sort: Optional[Dict[str, Any]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue an `updateOne` command against the collection.

    Args:
        filter (dict): Criteria to identify the document to update.
        update (dict): The update to apply to the document.
        sort (dict, optional): sort clause forwarded in the command payload.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database after the update operation.
    """
    update_command = make_payload(
        top_level="updateOne",
        filter=filter,
        update=update,
        sort=sort,
    )
    return await self._request(
        method=http_methods.POST,
        path=f"{self.base_path}",
        json_data=update_command,
        timeout_info=timeout_info,
    )
async def update_many(
    self,
    filter: Dict[str, Any],
    update: Dict[str, Any],
    options: Optional[Dict[str, Any]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue an `updateMany` command against the collection.

    Args:
        filter (dict): Criteria to identify the documents to update.
        update (dict): The update to apply to the documents.
        options (dict, optional): options forwarded in the command payload.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database after the update operation.
    """
    update_command = make_payload(
        top_level="updateMany",
        filter=filter,
        update=update,
        options=options,
    )
    return await self._request(
        method=http_methods.POST,
        path=f"{self.base_path}",
        json_data=update_command,
        timeout_info=timeout_info,
    )
async def replace(
    self, path: str, document: API_DOC, timeout_info: TimeoutInfoWideType = None
) -> API_RESPONSE:
    """
    Replace a document in the collection by delegating to `self._put`.

    Args:
        path (str): The path to the document to replace.
        document (dict): The new document to replace the existing one.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database after the replace operation.
    """
    put_response = await self._put(
        path=path,
        document=document,
        timeout_info=timeout_info,
    )
    return put_response
async def delete_one(
    self,
    id: str,
    sort: Optional[Dict[str, Any]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue a `deleteOne` command that targets a document by its `_id`.

    Args:
        id (str): The ID of the document to delete.
        sort (dict, optional): sort clause forwarded in the command payload.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database after the delete operation.
    """
    id_filter = {"_id": id}
    delete_command = make_payload(
        top_level="deleteOne",
        filter=id_filter,
        sort=sort,
    )
    return await self._request(
        method=http_methods.POST,
        path=f"{self.base_path}",
        json_data=delete_command,
        timeout_info=timeout_info,
    )
async def delete_one_by_predicate(
    self,
    filter: Dict[str, Any],
    sort: Optional[Dict[str, Any]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue a `deleteOne` command that targets a document by an arbitrary filter.

    Args:
        filter: any filter dictionary
        sort (dict, optional): sort clause forwarded in the command payload.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database after the delete operation.
    """
    delete_command = make_payload(
        top_level="deleteOne",
        filter=filter,
        sort=sort,
    )
    return await self._request(
        method=http_methods.POST,
        path=f"{self.base_path}",
        json_data=delete_command,
        timeout_info=timeout_info,
    )
async def delete_many(
    self,
    filter: Dict[str, Any],
    skip_error_check: bool = False,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue a single `deleteMany` command for the documents matching a filter.

    Args:
        filter (dict): Criteria to identify the documents to delete.
        skip_error_check (bool): whether to ignore the check for API error
            and return the response untouched. Default is False.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database after the delete operation.
    """
    # The payload is assembled by hand here (the filter is always included)
    delete_command = {"deleteMany": {"filter": filter}}
    return await self._request(
        method=http_methods.POST,
        path=f"{self.base_path}",
        json_data=delete_command,
        skip_error_check=skip_error_check,
        timeout_info=timeout_info,
    )
async def chunked_delete_many(
    self, filter: Dict[str, Any], timeout_info: TimeoutInfoWideType = None
) -> List[API_RESPONSE]:
    """
    Delete every document matching a filter by repeating `delete_many`
    until a response no longer reports `status.moreData`.

    Args:
        filter (dict): Criteria to identify the documents to delete.
        timeout_info: a float, or a TimeoutInfo dict, for each single HTTP request.
            This method runs a number of HTTP requests as it works on a
            pagination basis. The timeout refers to each individual such request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        List[dict]: The responses from the database from all the calls
    """
    collected: List[API_RESPONSE] = []
    while True:
        page_response = await self.delete_many(
            filter=filter, timeout_info=timeout_info
        )
        collected.append(page_response)
        # a missing "status" or "moreData" entry means there is nothing left
        if not page_response.get("status", {}).get("moreData", False):
            break
    return collected
async def clear(self, timeout_info: TimeoutInfoWideType = None) -> API_RESPONSE:
    """
    Clear the collection, deleting all documents, through a `delete_many`
    call with an empty filter.

    Args:
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database.

    Raises:
        ValueError: if the response's `status.deletedCount` is not -1.
    """
    clear_response = await self.delete_many(filter={}, timeout_info=timeout_info)
    reported_count = clear_response.get("status", {}).get("deletedCount")
    # -1 is the only value accepted as a successful clear
    if reported_count == -1:
        return clear_response
    raise ValueError(
        f"Could not issue a clear-collection API command (response: {json.dumps(clear_response)})."
    )
async def delete_subdocument(
    self, id: str, subdoc: str, timeout_info: TimeoutInfoWideType = None
) -> API_RESPONSE:
    """
    Remove a subdocument or field from a document by issuing a
    `findOneAndUpdate` command with an `$unset` update.

    Args:
        id (str): The ID of the document containing the subdocument.
        subdoc (str): The key of the subdocument or field to remove.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response from the database after the update operation.
    """
    unset_update = {"$unset": {subdoc: ""}}
    unset_command = {
        "findOneAndUpdate": {
            "filter": {"_id": id},
            "update": unset_update,
        }
    }
    return await self._request(
        method=http_methods.POST,
        path=f"{self.base_path}",
        json_data=unset_command,
        timeout_info=timeout_info,
    )
async def upsert(
    self, document: API_DOC, timeout_info: TimeoutInfoWideType = None
) -> str:
    """
    Alias of `upsert_one`: forwards the document and timeout unchanged.

    Args:
        document (dict): The document to insert or update.
        timeout_info: a float, or a TimeoutInfo dict, passed through
            to `upsert_one`.

    Returns:
        str: whatever `upsert_one` returns (the _id of the document).
    """
    upserted_id = await self.upsert_one(document, timeout_info=timeout_info)
    return upserted_id
async def upsert_one(
    self,
    document: API_DOC,
    timeout_info: TimeoutInfoWideType = None,
) -> str:
    """
    Emulate an upsert operation for a single document in the collection.

    An insertion is attempted first; if the API reports that a document with
    the same _id already exists, the existing document is replaced instead.

    Args:
        document (dict): The document to insert or update.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP requests.
            This method may issue one or two requests, depending on what
            is detected on DB. This timeout controls each HTTP request individually.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        str: The _id of the inserted or updated document.

    Raises:
        ValueError: if the insertion fails for a reason other than a
            pre-existing document, or if it reports no inserted IDs.
    """
    # First attempt: plain insertion, with API errors returned instead of raised
    insert_result = await self.insert_one(
        document, failures_allowed=True, timeout_info=timeout_info
    )

    if "errors" not in insert_result:
        inserted_ids = insert_result.get("status", {}).get("insertedIds", [])
        if not inserted_ids:
            raise ValueError("Unexplained empty insertedIds from API")
        return cast(str, inserted_ids[0])

    # Only the first reported error is inspected
    first_error = insert_result["errors"][0]
    if (
        "errorCode" not in first_error
        or first_error["errorCode"] != "DOCUMENT_ALREADY_EXISTS"
    ):
        raise ValueError(insert_result)

    # The document is already there: replace it, addressing it by _id
    replace_result = await self.find_one_and_replace(
        replacement=document,
        filter={"_id": document["_id"]},
        timeout_info=timeout_info,
    )
    return cast(str, replace_result["data"]["document"]["_id"])
async def upsert_many(
    self,
    documents: list[API_DOC],
    concurrency: int = 1,
    partial_failures_allowed: bool = False,
    timeout_info: TimeoutInfoWideType = None,
) -> List[Union[str, Exception]]:
    """
    Emulate an upsert operation for multiple documents in the collection,
    running one `upsert_one` per document under a concurrency limit.

    Args:
        documents (List[dict]): The documents to insert or update.
        concurrency (int, optional): The number of concurrent upserts.
        partial_failures_allowed (bool, optional): Whether to allow partial
            failures in the batch. When set, an exception raised for a
            document is returned in that document's position.
        timeout_info: a float, or a TimeoutInfo dict, for each HTTP request.
            This method issues a separate HTTP request for each document to
            insert: the timeout controls each such request individually.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        List[Union[str, Exception]]: A list of "_id"s of the inserted or updated documents.
    """
    limiter = asyncio.Semaphore(concurrency)

    async def _guarded_upsert(doc: API_DOC) -> str:
        # hold one semaphore slot for the duration of the upsert
        async with limiter:
            return await self.upsert_one(document=doc, timeout_info=timeout_info)

    pending = []
    for doc in documents:
        pending.append(asyncio.create_task(_guarded_upsert(doc)))

    outcomes = await asyncio.gather(
        *pending, return_exceptions=partial_failures_allowed
    )

    # BaseExceptions that are not Exceptions are never returned as results:
    # the first one found is re-raised.
    non_recoverable = next(
        (
            outcome
            for outcome in outcomes
            if isinstance(outcome, BaseException)
            and not isinstance(outcome, Exception)
        ),
        None,
    )
    if non_recoverable is not None:
        raise non_recoverable
    return outcomes  # type: ignore
class AstraDB:
    """
    Synchronous client bound to one namespace (keyspace) of an Astra DB
    database, exposing namespace-level JSON API commands (list, create,
    delete and truncate collections) and access to collection objects.
    """

    # Initialize the shared httpx client as a class attribute
    client = httpx.Client()

    def __init__(
        self,
        token: str,
        api_endpoint: str,
        api_path: Optional[str] = None,
        api_version: Optional[str] = None,
        namespace: Optional[str] = None,
        caller_name: Optional[str] = None,
        caller_version: Optional[str] = None,
    ) -> None:
        """
        Initialize an Astra DB instance.

        Args:
            token (str): Authentication token for Astra DB.
            api_endpoint (str): API endpoint URL.
            api_path (str, optional): used to override default URI construction
            api_version (str, optional): to override default URI construction
            namespace (str, optional): Namespace for the database.
            caller_name (str, optional): identity of the caller ("my_framework")
            caller_version (str, optional): version of the caller code ("1.0.3")

        Raises:
            AssertionError: if `token` or `api_endpoint` is None.
        """
        self.caller_name = caller_name
        self.caller_version = caller_version

        if token is None or api_endpoint is None:
            raise AssertionError("Must provide token and api_endpoint")

        if namespace is None:
            logger.info(
                f"ASTRA_DB_KEYSPACE is not set. Defaulting to '{DEFAULT_KEYSPACE_NAME}'"
            )
            namespace = DEFAULT_KEYSPACE_NAME

        # Store the API token
        self.token = token
        self.api_endpoint = api_endpoint

        # Set the Base URL for the API calls
        self.base_url = self.api_endpoint.strip("/")

        # Set the API version and path from the call
        self.api_path = (api_path or DEFAULT_JSON_API_PATH).strip("/")
        self.api_version = (api_version or DEFAULT_JSON_API_VERSION).strip("/")

        # Set the namespace
        self.namespace = namespace

        # Finally, construct the full base path
        self.base_path: str = f"/{self.api_path}/{self.api_version}/{self.namespace}"

    def __repr__(self) -> str:
        """Return a short description showing the endpoint and the keyspace."""
        return f'AstraDB[endpoint="{self.base_url}", keyspace="{self.namespace}"]'

    def __eq__(self, other: Any) -> bool:
        """
        Compare with another AstraDB on token, base URL, base path and
        caller identity; any other type compares unequal.
        """
        if isinstance(other, AstraDB):
            # work on the "normalized" quantities (stripped, etc)
            return all(
                [
                    self.token == other.token,
                    self.base_url == other.base_url,
                    self.base_path == other.base_path,
                    self.caller_name == other.caller_name,
                    self.caller_version == other.caller_version,
                ]
            )
        else:
            return False

    def copy(
        self,
        *,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        api_path: Optional[str] = None,
        api_version: Optional[str] = None,
        namespace: Optional[str] = None,
        caller_name: Optional[str] = None,
        caller_version: Optional[str] = None,
    ) -> AstraDB:
        """
        Build a new AstraDB from this one, overriding the settings passed
        with a truthy value and inheriting all the others.
        """
        return AstraDB(
            token=token or self.token,
            api_endpoint=api_endpoint or self.base_url,
            api_path=api_path or self.api_path,
            api_version=api_version or self.api_version,
            namespace=namespace or self.namespace,
            caller_name=caller_name or self.caller_name,
            caller_version=caller_version or self.caller_version,
        )

    def to_async(self) -> AsyncAstraDB:
        """Build an AsyncAstraDB carrying the same settings as this instance."""
        return AsyncAstraDB(
            token=self.token,
            api_endpoint=self.base_url,
            api_path=self.api_path,
            api_version=self.api_version,
            namespace=self.namespace,
            caller_name=self.caller_name,
            caller_version=self.caller_version,
        )

    def set_caller(
        self,
        caller_name: Optional[str] = None,
        caller_version: Optional[str] = None,
    ) -> None:
        """Overwrite the caller name and version stored on this instance."""
        self.caller_name = caller_name
        self.caller_version = caller_version

    def _request(
        self,
        method: str = http_methods.POST,
        path: Optional[str] = None,
        json_data: Optional[Dict[str, Any]] = None,
        url_params: Optional[Dict[str, Any]] = None,
        skip_error_check: bool = False,
        timeout_info: TimeoutInfoWideType = None,
    ) -> API_RESPONSE:
        """
        Send one API request through the shared client.

        The JSON body goes through `normalize_for_api` before being sent and
        the response goes through `restore_from_api` before being returned.
        """
        direct_response = api_request(
            client=self.client,
            base_url=self.base_url,
            auth_header=DEFAULT_AUTH_HEADER,
            token=self.token,
            method=method,
            json_data=normalize_for_api(json_data),
            url_params=url_params,
            path=path,
            skip_error_check=skip_error_check,
            caller_name=self.caller_name,
            caller_version=self.caller_version,
            timeout=to_httpx_timeout(timeout_info),
        )
        response = restore_from_api(direct_response)
        return response

    def post_raw_request(
        self, body: Dict[str, Any], timeout_info: TimeoutInfoWideType = None
    ) -> API_RESPONSE:
        """POST an arbitrary JSON body to the namespace base path."""
        return self._request(
            method=http_methods.POST,
            path=self.base_path,
            json_data=body,
            timeout_info=timeout_info,
        )

    def collection(self, collection_name: str) -> AstraDBCollection:
        """
        Retrieve a collection from the database.

        Args:
            collection_name (str): The name of the collection to retrieve.

        Returns:
            AstraDBCollection: The collection object.
        """
        return AstraDBCollection(collection_name=collection_name, astra_db=self)

    def get_collections(
        self,
        options: Optional[Dict[str, Any]] = None,
        timeout_info: TimeoutInfoWideType = None,
    ) -> API_RESPONSE:
        """
        Retrieve a list of collections from the database.

        Args:
            options (dict, optional): Options to get the collection list
            timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
                Note that a 'read' timeout event will not block the action taken
                by the API server if it has received the request already.

        Returns:
            dict: An object containing the list of collections in the database:
                {"status": {"collections": [...]}}
        """
        # Parse the options parameter
        if options is None:
            options = {}

        json_query = make_payload(
            top_level="findCollections",
            options=options,
        )

        response = self._request(
            method=http_methods.POST,
            path=self.base_path,
            json_data=json_query,
            timeout_info=timeout_info,
        )
        return response

    def create_collection(
        self,
        collection_name: str,
        *,
        options: Optional[Dict[str, Any]] = None,
        dimension: Optional[int] = None,
        metric: Optional[str] = None,
        service_dict: Optional[Dict[str, str]] = None,
        timeout_info: TimeoutInfoWideType = None,
    ) -> AstraDBCollection:
        """
        Create a new collection in the database.

        Args:
            collection_name (str): The name of the collection to create.
            options (dict, optional): Options for the collection. The dict
                passed by the caller is not modified.
            dimension (int, optional): Dimension for vector search.
            metric (str, optional): Metric choice for vector search.
            service_dict (dict, optional): a definition for the $vectorize service
                NOTE: This feature is under current development.
            timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
                Note that a 'read' timeout event will not block the action taken
                by the API server if it has received the request already.

        Returns:
            AstraDBCollection: The created collection object.

        Raises:
            ValueError: if a vector setting is given both as a named
                parameter and inside `options["vector"]`.
        """
        # options from named params
        vector_options = {
            k: v
            for k, v in {
                "dimension": dimension,
                "metric": metric,
                "service": service_dict,
            }.items()
            if v is not None
        }

        # overlap/merge with stuff in options.vector
        dup_params = set((options or {}).get("vector", {}).keys()) & set(
            vector_options.keys()
        )

        # If any params are duplicated, we raise an error
        if dup_params:
            dups = ", ".join(sorted(dup_params))
            raise ValueError(
                f"Parameter(s) {dups} passed both to the method and in the options"
            )

        # Build our options dictionary if we have vector options.
        # A new dict is created so that the caller's `options` (and its
        # nested "vector" dict) are left untouched.
        if vector_options:
            base_options = options or {}
            options = {
                **base_options,
                "vector": {
                    **base_options.get("vector", {}),
                    **vector_options,
                },
            }

        # Build the final json payload
        jsondata = {
            k: v
            for k, v in {"name": collection_name, "options": options}.items()
            if v is not None
        }

        # Make the request to the endpoint
        self._request(
            method=http_methods.POST,
            path=f"{self.base_path}",
            json_data={"createCollection": jsondata},
            timeout_info=timeout_info,
        )

        # Get the instance object as the return of the call
        return AstraDBCollection(astra_db=self, collection_name=collection_name)

    def delete_collection(
        self, collection_name: str, timeout_info: TimeoutInfoWideType = None
    ) -> API_RESPONSE:
        """
        Delete a collection from the database.

        Args:
            collection_name (str): The name of the collection to delete.
            timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
                Note that a 'read' timeout event will not block the action taken
                by the API server if it has received the request already.

        Returns:
            dict: The response from the database.

        Raises:
            ValueError: if `collection_name` is empty.
        """
        # Make sure we provide a collection name
        if not collection_name:
            raise ValueError("Must provide a collection name")

        response = self._request(
            method=http_methods.POST,
            path=f"{self.base_path}",
            json_data={"deleteCollection": {"name": collection_name}},
            timeout_info=timeout_info,
        )
        return response

    def truncate_collection(
        self, collection_name: str, timeout_info: TimeoutInfoWideType = None
    ) -> AstraDBCollection:
        """
        Clear a collection in the database, deleting all stored documents.

        Args:
            collection_name (str): The name of the collection to clear.
            timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
                Note that a 'read' timeout event will not block the action taken
                by the API server if it has received the request already.

        Returns:
            collection: an AstraDBCollection instance

        Raises:
            ValueError: if the clear response's `status.deletedCount` is not -1.
        """
        collection = AstraDBCollection(
            collection_name=collection_name,
            astra_db=self,
        )
        clear_response = collection.clear(timeout_info=timeout_info)

        if clear_response.get("status", {}).get("deletedCount") != -1:
            raise ValueError(
                f"Could not issue a truncation API command (response: {json.dumps(clear_response)})."
            )

        # return the collection itself
        return collection
class AsyncAstraDB: | |
def __init__(
    self,
    token: str,
    api_endpoint: str,
    api_path: Optional[str] = None,
    api_version: Optional[str] = None,
    namespace: Optional[str] = None,
    caller_name: Optional[str] = None,
    caller_version: Optional[str] = None,
) -> None:
    """
    Initialize an async Astra DB instance.

    Args:
        token (str): Authentication token for Astra DB.
        api_endpoint (str): API endpoint URL.
        api_path (str, optional): used to override default URI construction
        api_version (str, optional): to override default URI construction
        namespace (str, optional): Namespace for the database. When omitted,
            DEFAULT_KEYSPACE_NAME is used (and an info message is logged).
        caller_name (str, optional): identity of the caller ("my_framework")
        caller_version (str, optional): version of the caller code ("1.0.3")

    Raises:
        AssertionError: if either `token` or `api_endpoint` is None.
    """
    # Validate first: no HTTP client must be allocated (and then left
    # unclosed) for a call that is going to be rejected anyway.
    if token is None or api_endpoint is None:
        raise AssertionError("Must provide token and api_endpoint")

    if namespace is None:
        logger.info(
            f"ASTRA_DB_KEYSPACE is not set. Defaulting to '{DEFAULT_KEYSPACE_NAME}'"
        )
        namespace = DEFAULT_KEYSPACE_NAME

    self.caller_name = caller_name
    self.caller_version = caller_version

    # Each instance owns its own async client (closed by __aexit__)
    self.client = httpx.AsyncClient()

    # Store the API token
    self.token = token
    self.api_endpoint = api_endpoint

    # Set the Base URL for the API calls (surrounding slashes removed)
    self.base_url = self.api_endpoint.strip("/")

    # Set the API version and path from the call, falling back to defaults
    self.api_path = (api_path or DEFAULT_JSON_API_PATH).strip("/")
    self.api_version = (api_version or DEFAULT_JSON_API_VERSION).strip("/")

    # Set the namespace
    self.namespace = namespace

    # Finally, construct the full base path
    self.base_path: str = f"/{self.api_path}/{self.api_version}/{self.namespace}"
def __repr__(self) -> str:
    """Return a short description made of the base URL and the namespace."""
    endpoint, keyspace = self.base_url, self.namespace
    return f'AsyncAstraDB[endpoint="{endpoint}", keyspace="{keyspace}"]'
def __eq__(self, other: Any) -> bool:
    """
    Compare with another object.

    Returns True only when `other` is an AsyncAstraDB whose token, base URL,
    base path, caller name and caller version all equal this instance's.
    """
    if not isinstance(other, AsyncAstraDB):
        return False
    # work on the "normalized" quantities (stripped, etc)
    compared_fields = (
        "token",
        "base_url",
        "base_path",
        "caller_name",
        "caller_version",
    )
    # A list (not a generator) so that every field is read, as before
    return all(
        [getattr(self, name) == getattr(other, name) for name in compared_fields]
    )
async def __aenter__(self) -> AsyncAstraDB:
    """Enter the `async with` block, handing back this very instance."""
    return self
async def __aexit__(
    self,
    exc_type: Optional[Type[BaseException]] = None,
    exc_value: Optional[BaseException] = None,
    traceback: Optional[TracebackType] = None,
) -> None:
    """
    Leave the `async with` block, closing this instance's HTTP client.

    The exception details, if any, are not inspected; nothing is returned.
    """
    http_client = self.client
    await http_client.aclose()
def copy(
    self,
    *,
    token: Optional[str] = None,
    api_endpoint: Optional[str] = None,
    api_path: Optional[str] = None,
    api_version: Optional[str] = None,
    namespace: Optional[str] = None,
    caller_name: Optional[str] = None,
    caller_version: Optional[str] = None,
) -> AsyncAstraDB:
    """
    Create a new AsyncAstraDB based on this one.

    Each keyword argument, when truthy, replaces the corresponding setting;
    otherwise (None, empty string) the value of this instance is reused.
    The endpoint falls back to this instance's already-stripped base URL.

    Returns:
        AsyncAstraDB: a newly-constructed instance.
    """
    settings = {
        "token": token or self.token,
        "api_endpoint": api_endpoint or self.base_url,
        "api_path": api_path or self.api_path,
        "api_version": api_version or self.api_version,
        "namespace": namespace or self.namespace,
        "caller_name": caller_name or self.caller_name,
        "caller_version": caller_version or self.caller_version,
    }
    return AsyncAstraDB(**settings)
def to_sync(self) -> AstraDB:
    """
    Create an AstraDB (synchronous) instance with this instance's settings:
    token, base URL, API path and version, namespace and caller identity.
    """
    settings = {
        "token": self.token,
        "api_endpoint": self.base_url,
        "api_path": self.api_path,
        "api_version": self.api_version,
        "namespace": self.namespace,
        "caller_name": self.caller_name,
        "caller_version": self.caller_version,
    }
    return AstraDB(**settings)
def set_caller(
    self,
    caller_name: Optional[str] = None,
    caller_version: Optional[str] = None,
) -> None:
    """
    Overwrite the caller identity stored on this instance.

    Args:
        caller_name (str, optional): identity of the caller; None resets it.
        caller_version (str, optional): version of the caller; None resets it.
    """
    self.caller_name, self.caller_version = caller_name, caller_version
async def _request(
    self,
    method: str = http_methods.POST,
    path: Optional[str] = None,
    json_data: Optional[Dict[str, Any]] = None,
    url_params: Optional[Dict[str, Any]] = None,
    skip_error_check: bool = False,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Issue one API request through this instance's async HTTP client.

    The JSON payload goes through `normalize_for_api` before being sent and
    the response goes through `restore_from_api` before being returned.

    Args:
        method (str): the HTTP method (POST unless specified).
        path (str, optional): the path of the request.
        json_data (dict, optional): the JSON payload.
        url_params (dict, optional): the URL parameters.
        skip_error_check (bool): forwarded as-is to `async_api_request`.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.

    Returns:
        dict: the response, after `restore_from_api`.
    """
    payload = normalize_for_api(json_data)
    request_timeout = to_httpx_timeout(timeout_info)
    raw_response = await async_api_request(
        client=self.client,
        base_url=self.base_url,
        auth_header=DEFAULT_AUTH_HEADER,
        token=self.token,
        method=method,
        json_data=payload,
        url_params=url_params,
        path=path,
        skip_error_check=skip_error_check,
        caller_name=self.caller_name,
        caller_version=self.caller_version,
        timeout=request_timeout,
    )
    return restore_from_api(raw_response)
async def post_raw_request(
    self, body: Dict[str, Any], timeout_info: TimeoutInfoWideType = None
) -> API_RESPONSE:
    """
    POST an arbitrary, caller-supplied payload to this database's base path.

    Args:
        body (dict): the JSON payload, sent as provided.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.

    Returns:
        dict: the response to the request.
    """
    raw_response = await self._request(
        method=http_methods.POST,
        path=self.base_path,
        json_data=body,
        timeout_info=timeout_info,
    )
    return raw_response
async def collection(self, collection_name: str) -> AsyncAstraDBCollection:
    """
    Retrieve a collection object bound to this database.

    Args:
        collection_name (str): The name of the collection to retrieve.

    Returns:
        AsyncAstraDBCollection: The collection object.
    """
    handle = AsyncAstraDBCollection(collection_name=collection_name, astra_db=self)
    return handle
async def get_collections(
    self,
    options: Optional[Dict[str, Any]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> API_RESPONSE:
    """
    Retrieve a list of collections from the database.

    Args:
        options (dict, optional): Options to get the collection list;
            an empty dict is used when omitted.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: An object containing the list of collections in the database:
            {"status": {"collections": [...]}}
    """
    # The "findCollections" command always carries an options dict
    command = make_payload(
        top_level="findCollections",
        options={} if options is None else options,
    )
    return await self._request(
        method=http_methods.POST,
        path=self.base_path,
        json_data=command,
        timeout_info=timeout_info,
    )
async def create_collection(
    self,
    collection_name: str,
    *,
    options: Optional[Dict[str, Any]] = None,
    dimension: Optional[int] = None,
    metric: Optional[str] = None,
    service_dict: Optional[Dict[str, str]] = None,
    timeout_info: TimeoutInfoWideType = None,
) -> AsyncAstraDBCollection:
    """
    Create a new collection in the database.

    The named vector parameters (`dimension`, `metric`, `service_dict`) are
    merged into the "vector" entry of the options sent to the API. The dict
    passed as `options` is never modified: the merge works on a new dict.

    Args:
        collection_name (str): The name of the collection to create.
        options (dict, optional): Options for the collection.
        dimension (int, optional): Dimension for vector search.
        metric (str, optional): Metric choice for vector search.
        service_dict (dict, optional): a definition for the $vectorize service
            NOTE: This feature is under current development.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        AsyncAstraDBCollection: The created collection object.

    Raises:
        ValueError: if a vector setting is given both as a named parameter
            and within `options["vector"]`.
    """
    # options from named params (only those actually provided)
    named_settings = {
        "dimension": dimension,
        "metric": metric,
        "service": service_dict,
    }
    vector_options = {k: v for k, v in named_settings.items() if v is not None}

    # what the caller already placed in options.vector (None counts as empty)
    existing_vector = (options or {}).get("vector") or {}

    # If any params are duplicated, we raise an error
    dup_params = set(existing_vector) & set(vector_options)
    if dup_params:
        dups = ", ".join(sorted(dup_params))
        raise ValueError(
            f"Parameter(s) {dups} passed both to the method and in the options"
        )

    # Build our options dictionary if we have vector options: a new dict is
    # created, so that the caller's `options` object is left untouched.
    if vector_options:
        options = {
            **(options or {}),
            "vector": {**existing_vector, **vector_options},
        }

    # Build the final json payload, leaving out the entries that are None
    jsondata = {
        k: v
        for k, v in {"name": collection_name, "options": options}.items()
        if v is not None
    }

    # Make the request to the endpoint
    await self._request(
        method=http_methods.POST,
        path=f"{self.base_path}",
        json_data={"createCollection": jsondata},
        timeout_info=timeout_info,
    )

    # Get the instance object as the return of the call
    return AsyncAstraDBCollection(astra_db=self, collection_name=collection_name)
async def delete_collection(
    self, collection_name: str, timeout_info: TimeoutInfoWideType = None
) -> API_RESPONSE:
    """
    Drop a collection, identified by its name, from the database.

    Args:
        collection_name (str): The name of the collection to drop.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        dict: The response to the "deleteCollection" command.

    Raises:
        ValueError: if `collection_name` is empty or None.
    """
    if collection_name:
        command = {"deleteCollection": {"name": collection_name}}
        return await self._request(
            method=http_methods.POST,
            path=f"{self.base_path}",
            json_data=command,
            timeout_info=timeout_info,
        )
    # No usable name: refuse before anything is sent
    raise ValueError("Must provide a collection name")
async def truncate_collection(
    self, collection_name: str, timeout_info: TimeoutInfoWideType = None
) -> AsyncAstraDBCollection:
    """
    Empty a collection in the database, removing all documents it stores.

    Args:
        collection_name (str): The name of the collection to empty.
        timeout_info: a float, or a TimeoutInfo dict, for the HTTP request.
            Note that a 'read' timeout event will not block the action taken
            by the API server if it has received the request already.

    Returns:
        AsyncAstraDBCollection: the collection object the clearing was issued on.

    Raises:
        ValueError: if the response does not carry a status.deletedCount of -1.
    """
    target = AsyncAstraDBCollection(collection_name=collection_name, astra_db=self)
    outcome = await target.clear(timeout_info=timeout_info)
    deleted_count = outcome.get("status", {}).get("deletedCount")
    # A deletedCount of -1 is the only outcome accepted as a successful clear
    if deleted_count == -1:
        return target
    raise ValueError(
        f"Could not issue a truncation API command (response: {json.dumps(outcome)})."
    )