Spaces:
Running
Running
# Copyright DataStax, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from __future__ import annotations | |
from typing import ( | |
Any, | |
cast, | |
Dict, | |
Iterable, | |
List, | |
Optional, | |
TypedDict, | |
Union, | |
) | |
import json | |
import time | |
import datetime | |
import logging | |
import copy | |
import httpx | |
from astrapy import __version__ | |
from astrapy.core.defaults import DEFAULT_AUTH_HEADER, DEFAULT_TIMEOUT | |
from astrapy.core.core_types import API_RESPONSE | |
from astrapy.core.ids import ObjectId, UUID | |
class CustomLogger(logging.Logger): | |
def trace(self, msg: str, *args: Any, **kwargs: Any) -> None: | |
if self.isEnabledFor(5): | |
self._log(5, msg, args, **kwargs) | |
# Add a new TRACE logging level | |
logging.addLevelName(5, "TRACE") | |
# Tell the logging system to use your custom logger | |
logging.setLoggerClass(CustomLogger) | |
logger = logging.getLogger(__name__) | |
class http_methods: | |
GET = "GET" | |
POST = "POST" | |
PUT = "PUT" | |
PATCH = "PATCH" | |
DELETE = "DELETE" | |
package_name = __name__.split(".")[0] | |
user_agent_astrapy = f"{package_name}/{__version__}" | |
def detect_ragstack_user_agent() -> Optional[str]: | |
from importlib import metadata | |
from importlib.metadata import PackageNotFoundError | |
try: | |
ragstack_meta = metadata.metadata("ragstack-ai") | |
if ragstack_meta: | |
ragstack_version = ragstack_meta["version"] | |
return f"ragstack/{ragstack_version}" | |
except PackageNotFoundError: | |
pass | |
return None | |
user_agent_rs = detect_ragstack_user_agent() | |
def log_request( | |
method: str, | |
url: str, | |
params: Optional[Dict[str, Any]], | |
headers: Dict[str, str], | |
json_data: Optional[Dict[str, Any]], | |
) -> None: | |
""" | |
Log the details of an HTTP request for debugging purposes. | |
Args: | |
json_data (dict or None): The JSON payload sent with the request, if any. | |
""" | |
logger.debug(f"Request url: {url}") | |
logger.debug(f"Request method: {method}") | |
logger.debug(f"Request params: {params}") | |
# Redact the token from the request headers | |
headers_log = copy.deepcopy(headers) | |
if DEFAULT_AUTH_HEADER in headers_log: | |
headers_log[DEFAULT_AUTH_HEADER] = "AstraCS:<...>" | |
logger.debug(f"Request headers: {headers_log}") | |
if json_data: | |
logger.trace(f"Request payload: {json_data}") # type: ignore | |
def log_response(r: httpx.Response) -> None: | |
""" | |
Log the details of an HTTP response for debugging purposes. | |
Args: | |
r (requests.Response): The response object from the HTTP request. | |
""" | |
logger.debug(f"Response status code: {r.status_code}") | |
logger.debug(f"Response headers: {r.headers}") | |
logger.debug(f"Response content: {r.text}") | |
def compose_user_agent( | |
caller_name: Optional[str], caller_version: Optional[str] | |
) -> str: | |
user_agent_caller: Optional[str] = None | |
if caller_name: | |
if caller_version: | |
user_agent_caller = f"{caller_name}/{caller_version}" | |
else: | |
user_agent_caller = f"{caller_name}" | |
all_user_agents = [ | |
ua_block | |
for ua_block in [ | |
user_agent_rs, | |
user_agent_caller, | |
user_agent_astrapy, | |
] | |
if ua_block | |
] | |
return " ".join(all_user_agents) | |
class TimeoutInfo(TypedDict, total=False): | |
read: float | |
write: float | |
base: float | |
TimeoutInfoWideType = Union[TimeoutInfo, float, None] | |
def to_httpx_timeout(timeout_info: TimeoutInfoWideType) -> Union[httpx.Timeout, None]: | |
if timeout_info is None: | |
return None | |
if isinstance(timeout_info, float) or isinstance(timeout_info, int): | |
return httpx.Timeout(timeout_info) | |
elif isinstance(timeout_info, dict): | |
_base = timeout_info.get("base") or DEFAULT_TIMEOUT | |
_read = timeout_info.get("read") or _base | |
_write = timeout_info.get("write") or _base | |
return httpx.Timeout(_base, read=_read, write=_write) | |
else: | |
raise ValueError("Invalid timeout info provided.") | |
def make_request( | |
client: httpx.Client, | |
base_url: str, | |
auth_header: str, | |
token: str, | |
method: str, | |
json_data: Optional[Dict[str, Any]], | |
url_params: Optional[Dict[str, Any]], | |
path: Optional[str], | |
caller_name: Optional[str], | |
caller_version: Optional[str], | |
timeout: Optional[Union[httpx.Timeout, float]], | |
) -> httpx.Response: | |
""" | |
Make an HTTP request to a specified URL. | |
Args: | |
client (httpx): The httpx client for the request. | |
base_url (str): The base URL for the request. | |
auth_header (str): The authentication header key. | |
token (str): The token used for authentication. | |
method (str, optional): The HTTP method to use for the request. Default is POST. | |
path (str, optional): The specific path to append to the base URL. | |
json_data (dict, optional): JSON payload to be sent with the request. | |
url_params (dict, optional): URL parameters to be sent with the request. | |
Returns: | |
requests.Response: The response from the HTTP request. | |
""" | |
# Build the request headers from the token and user agent | |
request_headers = { | |
auth_header: token, | |
"User-Agent": compose_user_agent(caller_name, caller_version), | |
} | |
# Log the parameters of the request accordingly | |
log_request(method, f"{base_url}{path}", url_params, request_headers, json_data) | |
# Make the request | |
r = client.request( | |
method=method, | |
url=f"{base_url}{path}", | |
params=url_params, | |
content=json.dumps(json_data, allow_nan=False, separators=(",", ":")).encode(), | |
timeout=timeout or DEFAULT_TIMEOUT, | |
headers=request_headers, | |
) | |
# Log the response before returning it | |
log_response(r) | |
return r | |
async def amake_request( | |
client: httpx.AsyncClient, | |
base_url: str, | |
auth_header: str, | |
token: str, | |
method: str, | |
path: Optional[str], | |
json_data: Optional[Dict[str, Any]], | |
url_params: Optional[Dict[str, Any]], | |
caller_name: Optional[str], | |
caller_version: Optional[str], | |
timeout: Optional[Union[httpx.Timeout, float]], | |
) -> httpx.Response: | |
""" | |
Make an HTTP request to a specified URL. | |
Args: | |
client (httpx): The httpx client for the request. | |
base_url (str): The base URL for the request. | |
auth_header (str): The authentication header key. | |
token (str): The token used for authentication. | |
method (str, optional): The HTTP method to use for the request. Default is POST. | |
path (str, optional): The specific path to append to the base URL. | |
json_data (dict, optional): JSON payload to be sent with the request. | |
url_params (dict, optional): URL parameters to be sent with the request. | |
Returns: | |
requests.Response: The response from the HTTP request. | |
""" | |
# Build the request headers from the token and user agent | |
request_headers = { | |
auth_header: token, | |
"User-Agent": compose_user_agent(caller_name, caller_version), | |
} | |
# Log the parameters of the request accordingly | |
log_request(method, f"{base_url}{path}", url_params, request_headers, json_data) | |
# Make the request | |
r = await client.request( | |
method=method, | |
url=f"{base_url}{path}", | |
params=url_params, | |
content=json.dumps(json_data, allow_nan=False, separators=(",", ":")).encode(), | |
timeout=timeout or DEFAULT_TIMEOUT, | |
headers=request_headers, | |
) | |
# Log the response before returning it | |
log_response(r) | |
return r | |
def make_payload(top_level: str, **kwargs: Any) -> Dict[str, Any]: | |
""" | |
Construct a JSON payload for an HTTP request with a specified top-level key. | |
Args: | |
top_level (str): The top-level key for the JSON payload. | |
**kwargs: Arbitrary keyword arguments representing other keys and their values to be included in the payload. | |
Returns: | |
dict: The constructed JSON payload. | |
""" | |
params = {} | |
for key, value in kwargs.items(): | |
params[key] = value | |
json_query: Dict[str, Any] = {top_level: {}} | |
# Adding keys only if they're provided | |
for key, value in params.items(): | |
if value is not None: | |
json_query[top_level][key] = value | |
return json_query | |
def convert_vector_to_floats(vector: Iterable[Any]) -> List[float]: | |
""" | |
Convert a vector of strings to a vector of floats. | |
Args: | |
vector (list): A vector of objects. | |
Returns: | |
list: A vector of floats. | |
""" | |
return [float(value) for value in vector] | |
def is_list_of_floats(vector: Iterable[Any]) -> bool: | |
""" | |
Safely determine if it's a list of floats. | |
Assumption: if list, and first item is float, then all items are. | |
""" | |
return isinstance(vector, list) and ( | |
len(vector) == 0 or isinstance(vector[0], float) or isinstance(vector[0], int) | |
) | |
def convert_to_ejson_date_object( | |
date_value: Union[datetime.date, datetime.datetime] | |
) -> Dict[str, int]: | |
return {"$date": int(time.mktime(date_value.timetuple()) * 1000)} | |
def convert_to_ejson_uuid_object(uuid_value: UUID) -> Dict[str, str]: | |
return {"$uuid": str(uuid_value)} | |
def convert_to_ejson_objectid_object(objectid_value: ObjectId) -> Dict[str, str]: | |
return {"$objectId": str(objectid_value)} | |
def convert_ejson_date_object_to_datetime( | |
date_object: Dict[str, int] | |
) -> datetime.datetime: | |
return datetime.datetime.fromtimestamp(date_object["$date"] / 1000.0) | |
def convert_ejson_uuid_object_to_uuid(uuid_object: Dict[str, str]) -> UUID: | |
return UUID(uuid_object["$uuid"]) | |
def convert_ejson_objectid_object_to_objectid( | |
objectid_object: Dict[str, str] | |
) -> ObjectId: | |
return ObjectId(objectid_object["$objectId"]) | |
def _normalize_payload_value(path: List[str], value: Any) -> Any: | |
""" | |
The path helps determining special treatments | |
""" | |
_l2 = ".".join(path[-2:]) | |
_l1 = ".".join(path[-1:]) | |
if _l1 == "$vector" and _l2 != "projection.$vector": | |
if not is_list_of_floats(value): | |
return convert_vector_to_floats(value) | |
else: | |
return value | |
else: | |
if isinstance(value, dict): | |
return { | |
k: _normalize_payload_value(path + [k], v) for k, v in value.items() | |
} | |
elif isinstance(value, list): | |
return [ | |
_normalize_payload_value(path + [""], list_item) for list_item in value | |
] | |
else: | |
if isinstance(value, datetime.datetime) or isinstance(value, datetime.date): | |
return convert_to_ejson_date_object(value) | |
elif isinstance(value, UUID): | |
return convert_to_ejson_uuid_object(value) | |
elif isinstance(value, ObjectId): | |
return convert_to_ejson_objectid_object(value) | |
else: | |
return value | |
def normalize_for_api( | |
payload: Union[Dict[str, Any], None] | |
) -> Union[Dict[str, Any], None]: | |
""" | |
Normalize a payload for API calls. | |
This includes e.g. ensuring values for "$vector" key | |
are made into plain lists of floats. | |
Args: | |
payload (Dict[str, Any]): A dict expressing a payload for an API call | |
Returns: | |
Dict[str, Any]: a "normalized" payload dict | |
""" | |
if payload: | |
return cast(Dict[str, Any], _normalize_payload_value([], payload)) | |
else: | |
return payload | |
def _restore_response_value(path: List[str], value: Any) -> Any: | |
""" | |
The path helps determining special treatments | |
""" | |
if isinstance(value, dict): | |
if len(value) == 1 and "$date" in value: | |
# this is `{"$date": 123456}`, restore to datetime.datetime | |
return convert_ejson_date_object_to_datetime(value) | |
elif len(value) == 1 and "$uuid" in value: | |
# this is `{"$uuid": "abc123..."}`, restore to UUID | |
return convert_ejson_uuid_object_to_uuid(value) | |
elif len(value) == 1 and "$objectId" in value: | |
# this is `{"$objectId": "123abc..."}`, restore to ObjectId | |
return convert_ejson_objectid_object_to_objectid(value) | |
else: | |
return {k: _restore_response_value(path + [k], v) for k, v in value.items()} | |
elif isinstance(value, list): | |
return [_restore_response_value(path + [""], list_item) for list_item in value] | |
else: | |
return value | |
def restore_from_api(response: API_RESPONSE) -> API_RESPONSE: | |
""" | |
Process a dictionary just returned from the API. | |
This is the place where e.g. `{"$date": 123}` is | |
converted back into a datetime object. | |
""" | |
return cast(API_RESPONSE, _restore_response_value([], response)) | |