Draken007's picture
Upload 7228 files
2a0bc63 verified
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import (
Any,
cast,
Dict,
Iterable,
List,
Optional,
TypedDict,
Union,
)
import json
import time
import datetime
import logging
import copy
import httpx
from astrapy import __version__
from astrapy.core.defaults import DEFAULT_AUTH_HEADER, DEFAULT_TIMEOUT
from astrapy.core.core_types import API_RESPONSE
from astrapy.core.ids import ObjectId, UUID
class CustomLogger(logging.Logger):
def trace(self, msg: str, *args: Any, **kwargs: Any) -> None:
if self.isEnabledFor(5):
self._log(5, msg, args, **kwargs)
# Add a new TRACE logging level
logging.addLevelName(5, "TRACE")
# Tell the logging system to use your custom logger
logging.setLoggerClass(CustomLogger)
logger = logging.getLogger(__name__)
class http_methods:
    """
    HTTP method names as plain string constants.

    The values are the upper-case verbs themselves, suitable for the
    ``method`` argument of the request helpers in this module.
    """

    # Retrieval.
    GET = "GET"
    # Creation / command submission.
    POST = "POST"
    # Full and partial replacement.
    PUT = "PUT"
    PATCH = "PATCH"
    # Removal.
    DELETE = "DELETE"
# First component of this module's dotted name (the top-level package).
package_name = __name__.partition(".")[0]
# "<package>/<version>" fragment for the User-Agent header.
user_agent_astrapy = "{}/{}".format(package_name, __version__)
def detect_ragstack_user_agent() -> Optional[str]:
    """
    Build the User-Agent fragment for the ``ragstack-ai`` distribution.

    Returns:
        str or None: ``"ragstack/<version>"`` when the ``ragstack-ai``
        distribution is installed and reports a version; None otherwise.
    """
    from importlib import metadata

    try:
        ragstack_version = metadata.version("ragstack-ai")
    except metadata.PackageNotFoundError:
        return None
    # Distribution metadata lacking a Version field yields None here; do not
    # emit a meaningless "ragstack/None" fragment in that case.
    if not ragstack_version:
        return None
    return f"ragstack/{ragstack_version}"


user_agent_rs = detect_ragstack_user_agent()
def log_request(
    method: str,
    url: str,
    params: Optional[Dict[str, Any]],
    headers: Dict[str, str],
    json_data: Optional[Dict[str, Any]],
) -> None:
    """
    Log the details of an HTTP request for debugging purposes.

    URL, method, params and (redacted) headers are logged at DEBUG. The
    payload is logged at level 5 (TRACE). Nothing is formatted unless the
    corresponding level is enabled on the module logger.

    Args:
        method (str): The HTTP method of the request.
        url (str): The full URL of the request.
        params (dict or None): The URL query parameters, if any.
        headers (dict): The request headers. Not modified.
        json_data (dict or None): The JSON payload sent with the request, if any.
    """
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"Request url: {url}")
        logger.debug(f"Request method: {method}")
        logger.debug(f"Request params: {params}")
        # Redact credentials. The token may travel under DEFAULT_AUTH_HEADER
        # or under the standard Authorization header (callers choose the
        # header name). HTTP header names are case-insensitive.
        sensitive_names = {DEFAULT_AUTH_HEADER.lower(), "authorization"}
        headers_log = {
            name: ("AstraCS:<...>" if name.lower() in sensitive_names else value)
            for name, value in headers.items()
        }
        logger.debug(f"Request headers: {headers_log}")
    # Payloads may be large (e.g. vectors): stringify only when TRACE is on.
    if json_data and logger.isEnabledFor(5):
        logger.trace(f"Request payload: {json_data}")  # type: ignore
def log_response(r: httpx.Response) -> None:
    """
    Log the details of an HTTP response for debugging purposes.

    Returns immediately unless the module logger is enabled for DEBUG, so
    the response body is not decoded via ``r.text`` just for logging.

    Args:
        r (httpx.Response): The response object from the HTTP request.
    """
    if not logger.isEnabledFor(logging.DEBUG):
        return
    logger.debug(f"Response status code: {r.status_code}")
    logger.debug(f"Response headers: {r.headers}")
    logger.debug(f"Response content: {r.text}")
def compose_user_agent(
    caller_name: Optional[str], caller_version: Optional[str]
) -> str:
    """
    Build the value of the User-Agent header.

    The non-empty parts among the following are joined with single spaces,
    in this order:

    - the ragstack fragment, if one was detected;
    - the caller fragment: "name/version", just "name" when no version is
      given, and omitted when there is no name;
    - the fragment for this package.
    """
    caller_fragment: Optional[str]
    if not caller_name:
        caller_fragment = None
    elif caller_version:
        caller_fragment = f"{caller_name}/{caller_version}"
    else:
        caller_fragment = f"{caller_name}"
    candidates = (user_agent_rs, caller_fragment, user_agent_astrapy)
    return " ".join(filter(None, candidates))
class TimeoutInfo(TypedDict, total=False):
    """
    Fine-grained timeout settings. Every key is optional.

    Keys:
        read: timeout applied to reading the response.
        write: timeout applied to sending the request.
        base: timeout for all other phases; also the fallback for
            ``read`` and ``write`` when those are not given.
    """

    read: float
    write: float
    base: float


# Everything accepted by ``to_httpx_timeout``: a TimeoutInfo dict, a number, or None.
TimeoutInfoWideType = Union[TimeoutInfo, float, None]
def to_httpx_timeout(timeout_info: TimeoutInfoWideType) -> Union[httpx.Timeout, None]:
    """
    Translate a loose timeout specification into an ``httpx.Timeout``.

    Args:
        timeout_info: None (returned as None), a single number (applied to
            every phase), or a TimeoutInfo dict. In the dict, a missing or
            falsy ``base`` falls back to DEFAULT_TIMEOUT, and a missing or
            falsy ``read``/``write`` falls back to the base value.

    Returns:
        httpx.Timeout or None.

    Raises:
        ValueError: for any other kind of input.
    """
    if timeout_info is None:
        return None
    if isinstance(timeout_info, (float, int)):
        return httpx.Timeout(timeout_info)
    if not isinstance(timeout_info, dict):
        raise ValueError("Invalid timeout info provided.")
    base_timeout = timeout_info.get("base") or DEFAULT_TIMEOUT
    return httpx.Timeout(
        base_timeout,
        read=timeout_info.get("read") or base_timeout,
        write=timeout_info.get("write") or base_timeout,
    )
def make_request(
    client: httpx.Client,
    base_url: str,
    auth_header: str,
    token: str,
    method: str,
    json_data: Optional[Dict[str, Any]],
    url_params: Optional[Dict[str, Any]],
    path: Optional[str],
    caller_name: Optional[str],
    caller_version: Optional[str],
    timeout: Optional[Union[httpx.Timeout, float]],
) -> httpx.Response:
    """
    Make a synchronous HTTP request to a specified URL.

    Args:
        client (httpx.Client): The httpx client used to send the request.
        base_url (str): The base URL for the request.
        auth_header (str): The name of the header carrying the token.
        token (str): The token used for authentication.
        method (str): The HTTP method to use for the request (e.g. "POST").
        json_data (dict, optional): JSON payload to be sent with the request.
            It is always serialized, so None is sent as the body ``null``.
        url_params (dict, optional): URL parameters to be sent with the request.
        path (str, optional): The specific path to append to the base URL.
            None is treated as an empty path.
        caller_name (str, optional): Caller name, added to the User-Agent.
        caller_version (str, optional): Caller version, added to the User-Agent.
        timeout (httpx.Timeout or float, optional): Request timeout. Falsy
            values fall back to DEFAULT_TIMEOUT.

    Returns:
        httpx.Response: The response from the HTTP request.

    Raises:
        ValueError: if ``json_data`` contains NaN or infinite floats, which
            are rejected by the ``allow_nan=False`` serialization.
    """
    # `path` is Optional: without the fallback, None would end up in the URL
    # as the literal text "None".
    url = f"{base_url}{path or ''}"
    request_headers = {
        auth_header: token,
        "User-Agent": compose_user_agent(caller_name, caller_version),
    }
    log_request(method, url, url_params, request_headers, json_data)
    # Compact separators keep the body small. allow_nan=False refuses values
    # that are not valid JSON instead of sending them.
    body = json.dumps(json_data, allow_nan=False, separators=(",", ":")).encode()
    r = client.request(
        method=method,
        url=url,
        params=url_params,
        content=body,
        timeout=timeout or DEFAULT_TIMEOUT,
        headers=request_headers,
    )
    log_response(r)
    return r
async def amake_request(
    client: httpx.AsyncClient,
    base_url: str,
    auth_header: str,
    token: str,
    method: str,
    path: Optional[str],
    json_data: Optional[Dict[str, Any]],
    url_params: Optional[Dict[str, Any]],
    caller_name: Optional[str],
    caller_version: Optional[str],
    timeout: Optional[Union[httpx.Timeout, float]],
) -> httpx.Response:
    """
    Make an asynchronous HTTP request to a specified URL.

    Note that the positional parameter order differs from ``make_request``:
    here ``path`` comes before ``json_data`` and ``url_params``.

    Args:
        client (httpx.AsyncClient): The async httpx client used to send the request.
        base_url (str): The base URL for the request.
        auth_header (str): The name of the header carrying the token.
        token (str): The token used for authentication.
        method (str): The HTTP method to use for the request (e.g. "POST").
        path (str, optional): The specific path to append to the base URL.
            None is treated as an empty path.
        json_data (dict, optional): JSON payload to be sent with the request.
            It is always serialized, so None is sent as the body ``null``.
        url_params (dict, optional): URL parameters to be sent with the request.
        caller_name (str, optional): Caller name, added to the User-Agent.
        caller_version (str, optional): Caller version, added to the User-Agent.
        timeout (httpx.Timeout or float, optional): Request timeout. Falsy
            values fall back to DEFAULT_TIMEOUT.

    Returns:
        httpx.Response: The response from the HTTP request.

    Raises:
        ValueError: if ``json_data`` contains NaN or infinite floats, which
            are rejected by the ``allow_nan=False`` serialization.
    """
    # `path` is Optional: without the fallback, None would end up in the URL
    # as the literal text "None".
    url = f"{base_url}{path or ''}"
    request_headers = {
        auth_header: token,
        "User-Agent": compose_user_agent(caller_name, caller_version),
    }
    log_request(method, url, url_params, request_headers, json_data)
    # Compact separators keep the body small. allow_nan=False refuses values
    # that are not valid JSON instead of sending them.
    body = json.dumps(json_data, allow_nan=False, separators=(",", ":")).encode()
    r = await client.request(
        method=method,
        url=url,
        params=url_params,
        content=body,
        timeout=timeout or DEFAULT_TIMEOUT,
        headers=request_headers,
    )
    log_response(r)
    return r
def make_payload(top_level: str, **kwargs: Any) -> Dict[str, Any]:
    """
    Construct a JSON payload for an HTTP request with a specified top-level key.

    Keyword arguments whose value is None are omitted. Every other value is
    kept, including falsy ones such as 0, "" or [].

    Args:
        top_level (str): The top-level key for the JSON payload.
        **kwargs: Arbitrary keyword arguments representing other keys and their values to be included in the payload.

    Returns:
        dict: ``{top_level: {key: value, ...}}``. The inner dict is empty when
        no keyword argument has a non-None value.
    """
    return {
        top_level: {key: value for key, value in kwargs.items() if value is not None}
    }
def convert_vector_to_floats(vector: Iterable[Any]) -> List[float]:
    """
    Coerce every element of a vector to ``float``.

    Args:
        vector (iterable): Any iterable of values accepted by ``float()``,
            e.g. numeric strings, ints, floats or numeric scalar objects.

    Returns:
        list: A new list of floats, in the same order.

    Raises:
        Whatever ``float()`` raises for an unconvertible element
        (e.g. ValueError or TypeError).
    """
    return list(map(float, vector))
def is_list_of_floats(vector: Iterable[Any]) -> bool:
    """
    Cheaply decide whether ``vector`` is already a list of numbers.

    Only the first element is inspected. A list whose first item is a float
    or an int (bool included, as an int subclass) is assumed homogeneous.
    Empty lists qualify. Non-list iterables (tuples, generators, arrays)
    never do.
    """
    if not isinstance(vector, list):
        return False
    if not vector:
        return True
    return isinstance(vector[0], (float, int))
def convert_to_ejson_date_object(
    date_value: Union[datetime.date, datetime.datetime]
) -> Dict[str, int]:
    """
    Convert a date or datetime into an EJSON ``{"$date": <ms since epoch>}`` object.

    - Timezone-aware datetimes are converted using their own UTC offset.
    - Naive datetimes are interpreted in the local timezone (``time.mktime``).
    - Plain dates are taken at local midnight.
    - Precision is kept down to the millisecond; anything finer is rounded
      down.

    Args:
        date_value: The date or datetime to convert.

    Returns:
        dict: ``{"$date": <int milliseconds since the Unix epoch>}``.
    """
    if isinstance(date_value, datetime.datetime) and date_value.utcoffset() is not None:
        # Aware datetime: exact integer timedelta arithmetic honours the
        # value's own offset. time.mktime would ignore tzinfo and read the
        # wall-clock fields as local time.
        epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
        return {"$date": (date_value - epoch) // datetime.timedelta(milliseconds=1)}
    # Naive datetime or plain date: local time, as time.mktime assumes.
    # timetuple() carries whole seconds only, so add the milliseconds back.
    whole_seconds = int(time.mktime(date_value.timetuple()))
    millis = getattr(date_value, "microsecond", 0) // 1000
    return {"$date": whole_seconds * 1000 + millis}
def convert_to_ejson_uuid_object(uuid_value: UUID) -> Dict[str, str]:
    """Wrap a UUID as an EJSON ``{"$uuid": str(uuid_value)}`` object."""
    as_text = str(uuid_value)
    return {"$uuid": as_text}
def convert_to_ejson_objectid_object(objectid_value: ObjectId) -> Dict[str, str]:
    """Wrap an ObjectId as an EJSON ``{"$objectId": str(objectid_value)}`` object."""
    as_text = str(objectid_value)
    return {"$objectId": as_text}
def convert_ejson_date_object_to_datetime(
    date_object: Dict[str, int]
) -> datetime.datetime:
    """
    Turn an EJSON ``{"$date": <ms since epoch>}`` object into a datetime.

    The result is naive and expressed in local time
    (``datetime.datetime.fromtimestamp`` without a tz argument).
    """
    epoch_seconds = date_object["$date"] / 1000.0
    return datetime.datetime.fromtimestamp(epoch_seconds)
def convert_ejson_uuid_object_to_uuid(uuid_object: Dict[str, str]) -> UUID:
    """Rebuild a UUID from an EJSON ``{"$uuid": "<uuid string>"}`` object."""
    uuid_text = uuid_object["$uuid"]
    return UUID(uuid_text)
def convert_ejson_objectid_object_to_objectid(
    objectid_object: Dict[str, str]
) -> ObjectId:
    """Rebuild an ObjectId from an EJSON ``{"$objectId": "<hex string>"}`` object."""
    objectid_text = objectid_object["$objectId"]
    return ObjectId(objectid_text)
def _normalize_payload_value(path: List[str], value: Any) -> Any:
    """
    Recursively prepare ``value``, found at key-path ``path``, for the API.

    The path helps determining special treatments:

    - Under a ``$vector`` key the value is coerced to a list of floats,
      unless ``is_list_of_floats`` already accepts it. The exception is a
      key-path ending in ``projection.$vector``, which is left alone.
    - Dicts and lists are rebuilt recursively. List items extend the path
      with an empty-string component.
    - date/datetime, UUID and ObjectId leaves become EJSON wrapper dicts.
    - Every other value is returned unchanged.
    """
    last_key = ".".join(path[-1:])
    last_two_keys = ".".join(path[-2:])
    if last_key == "$vector" and last_two_keys != "projection.$vector":
        # A vector: keep it if it already looks numeric, else coerce it.
        return value if is_list_of_floats(value) else convert_vector_to_floats(value)
    if isinstance(value, dict):
        return {
            sub_key: _normalize_payload_value([*path, sub_key], sub_value)
            for sub_key, sub_value in value.items()
        }
    if isinstance(value, list):
        return [_normalize_payload_value([*path, ""], item) for item in value]
    # Leaf values: only a few types need an EJSON representation.
    if isinstance(value, (datetime.datetime, datetime.date)):
        return convert_to_ejson_date_object(value)
    if isinstance(value, UUID):
        return convert_to_ejson_uuid_object(value)
    if isinstance(value, ObjectId):
        return convert_to_ejson_objectid_object(value)
    return value
def normalize_for_api(
    payload: Union[Dict[str, Any], None]
) -> Union[Dict[str, Any], None]:
    """
    Normalize a payload for API calls.

    For example, values under a "$vector" key are made into plain lists of
    floats, and dates/UUIDs/ObjectIds become EJSON objects. A falsy payload
    (None or an empty dict) is returned as the very same object.

    Args:
        payload (Dict[str, Any] or None): A dict expressing a payload for an API call.

    Returns:
        Dict[str, Any] or None: A new "normalized" payload dict, or the falsy input.
    """
    if not payload:
        return payload
    normalized = _normalize_payload_value([], payload)
    return cast(Dict[str, Any], normalized)
def _restore_response_value(path: List[str], value: Any) -> Any:
"""
The path helps determining special treatments
"""
if isinstance(value, dict):
if len(value) == 1 and "$date" in value:
# this is `{"$date": 123456}`, restore to datetime.datetime
return convert_ejson_date_object_to_datetime(value)
elif len(value) == 1 and "$uuid" in value:
# this is `{"$uuid": "abc123..."}`, restore to UUID
return convert_ejson_uuid_object_to_uuid(value)
elif len(value) == 1 and "$objectId" in value:
# this is `{"$objectId": "123abc..."}`, restore to ObjectId
return convert_ejson_objectid_object_to_objectid(value)
else:
return {k: _restore_response_value(path + [k], v) for k, v in value.items()}
elif isinstance(value, list):
return [_restore_response_value(path + [""], list_item) for list_item in value]
else:
return value
def restore_from_api(response: API_RESPONSE) -> API_RESPONSE:
    """
    Post-process a structure just returned from the API.

    This is where EJSON wrappers such as ``{"$date": 123}`` are turned back
    into Python objects (a datetime in that example), recursively.
    """
    restored = _restore_response_value([], response)
    return cast(API_RESPONSE, restored)