File size: 13,214 Bytes
2a0bc63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations
from typing import (
    Any,
    cast,
    Dict,
    Iterable,
    List,
    Optional,
    TypedDict,
    Union,
)
import json
import time
import datetime
import logging
import copy

import httpx

from astrapy import __version__
from astrapy.core.defaults import DEFAULT_AUTH_HEADER, DEFAULT_TIMEOUT
from astrapy.core.core_types import API_RESPONSE
from astrapy.core.ids import ObjectId, UUID


class CustomLogger(logging.Logger):
    """Logger subclass that adds a `trace` method logging at numeric level 5."""

    def trace(self, msg: str, *args: Any, **kwargs: Any) -> None:
        """
        Emit `msg` at level 5, provided that level is enabled for this logger.

        Args:
            msg (str): The message, possibly a %-style format string.
            *args: Values to be merged into `msg` by the logging machinery.
            **kwargs: Passed through unchanged to `logging.Logger._log`.
        """
        # Nothing to do when level 5 is filtered out for this logger.
        if not self.isEnabledFor(5):
            return
        self._log(5, msg, args, **kwargs)


# Register the name "TRACE" for numeric level 5 (the level used by
# CustomLogger.trace), so that log output shows "TRACE" for such records.
logging.addLevelName(5, "TRACE")

# Make CustomLogger the class instantiated by later logging.getLogger calls.
# This must come before the getLogger call below for `logger` to have `.trace`.
logging.setLoggerClass(CustomLogger)


# Module-level logger used by the request/response logging helpers in this file.
logger = logging.getLogger(__name__)


class http_methods:
    """Namespace grouping HTTP method names as plain string constants."""

    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    PATCH = "PATCH"
    DELETE = "DELETE"


# First dotted component of this module's import path, i.e. the top-level package.
package_name = __name__.split(".")[0]

# User-agent token for this library, in "<package>/<version>" form.
user_agent_astrapy = f"{package_name}/{__version__}"


def detect_ragstack_user_agent() -> Optional[str]:
    """
    Build a user-agent token for the "ragstack-ai" distribution, if installed.

    Returns:
        str or None: "ragstack/<version>" when metadata for "ragstack-ai" is
        found and non-empty, None otherwise.
    """
    from importlib import metadata
    from importlib.metadata import PackageNotFoundError

    # Only the metadata lookup can raise PackageNotFoundError.
    try:
        ragstack_meta = metadata.metadata("ragstack-ai")
    except PackageNotFoundError:
        return None

    if not ragstack_meta:
        return None

    ragstack_version = ragstack_meta["version"]
    return f"ragstack/{ragstack_version}"


# ragstack user-agent token, computed once at import time
# (None when the "ragstack-ai" distribution is not found).
user_agent_rs = detect_ragstack_user_agent()


def log_request(
    method: str,
    url: str,
    params: Optional[Dict[str, Any]],
    headers: Dict[str, str],
    json_data: Optional[Dict[str, Any]],
) -> None:
    """
    Log the details of an HTTP request for debugging purposes.

    URL, method, query parameters and headers are logged at DEBUG level;
    a non-empty payload is logged at the custom TRACE level (5).
    The value stored under `DEFAULT_AUTH_HEADER`, if present, is redacted
    before the headers are logged.

    Args:
        method (str): The HTTP method of the request.
        url (str): The full URL the request is sent to.
        params (dict or None): The URL query parameters, if any.
        headers (dict): The request headers. This dict is not modified.
        json_data (dict or None): The JSON payload sent with the request, if any.
    """
    # TRACE (5) sits below DEBUG (10): when DEBUG is disabled none of the
    # messages below can be emitted, so skip the header copy and all the
    # message formatting (payloads can be large).
    if not logger.isEnabledFor(logging.DEBUG):
        return

    # %-style arguments defer string formatting to the logging machinery.
    logger.debug("Request url: %s", url)
    logger.debug("Request method: %s", method)
    logger.debug("Request params: %s", params)

    # Redact the token on a copy, since the caller's dict is the one being
    # sent. A shallow copy is enough: only one top-level value is replaced.
    # NOTE(review): only DEFAULT_AUTH_HEADER is redacted; a token sent under
    # a different header name would be logged as-is -- confirm with callers.
    headers_log = dict(headers)
    if DEFAULT_AUTH_HEADER in headers_log:
        headers_log[DEFAULT_AUTH_HEADER] = "AstraCS:<...>"

    logger.debug("Request headers: %s", headers_log)

    if json_data:
        logger.trace("Request payload: %s", json_data)  # type: ignore


def log_response(r: httpx.Response) -> None:
    """
    Log the details of an HTTP response for debugging purposes.

    Status code, headers and body text are logged at DEBUG level.

    Args:
        r (httpx.Response): The response object from the HTTP request.
    """
    # Check the level first: accessing `r.text` decodes the response body,
    # which is wasted work on every request when DEBUG logging is off.
    if not logger.isEnabledFor(logging.DEBUG):
        return

    # %-style arguments defer string formatting to the logging machinery.
    logger.debug("Response status code: %s", r.status_code)
    logger.debug("Response headers: %s", r.headers)
    logger.debug("Response content: %s", r.text)


def compose_user_agent(
    caller_name: Optional[str], caller_version: Optional[str]
) -> str:
    """
    Assemble the User-Agent header value for outgoing requests.

    The result is the space-separated sequence of, in this order: the ragstack
    token, the caller token and this library's token, skipping any that is
    empty or None.

    Args:
        caller_name (str or None): Name of the calling package, if any.
        caller_version (str or None): Version of the calling package; it is
            used only when `caller_name` is also provided.

    Returns:
        str: The composed user-agent string.
    """
    caller_block: Optional[str]
    if not caller_name:
        caller_block = None
    elif caller_version:
        caller_block = f"{caller_name}/{caller_version}"
    else:
        caller_block = f"{caller_name}"

    # filter(None, ...) discards the falsy (missing) blocks, preserving order.
    candidate_blocks = (user_agent_rs, caller_block, user_agent_astrapy)
    return " ".join(filter(None, candidate_blocks))


class TimeoutInfo(TypedDict, total=False):
    """
    Timeout settings expressed as a dict; every key is optional (`total=False`).

    `to_httpx_timeout` in this module converts such a dict to an `httpx.Timeout`.
    """

    # Read timeout; `to_httpx_timeout` uses "base" when this is missing or falsy.
    read: float
    # Write timeout; `to_httpx_timeout` uses "base" when this is missing or falsy.
    write: float
    # General timeout; `to_httpx_timeout` uses DEFAULT_TIMEOUT when missing or falsy.
    base: float


# Everything accepted as a timeout specification by `to_httpx_timeout`:
# a TimeoutInfo dict, a single number, or None.
TimeoutInfoWideType = Union[TimeoutInfo, float, None]


def to_httpx_timeout(timeout_info: TimeoutInfoWideType) -> Union[httpx.Timeout, None]:
    """
    Convert a loosely-typed timeout specification into an `httpx.Timeout`.

    Args:
        timeout_info: None, a number, or a dict with any of the keys
            "base", "read" and "write".

    Returns:
        httpx.Timeout or None: None for a None input; otherwise a Timeout built
        from the number, or from the dict (where a missing or falsy "base"
        becomes DEFAULT_TIMEOUT, and missing or falsy "read"/"write" take the
        base value).

    Raises:
        ValueError: If `timeout_info` is none of the supported kinds.
    """
    if timeout_info is None:
        return None
    if isinstance(timeout_info, (float, int)):
        return httpx.Timeout(timeout_info)
    if not isinstance(timeout_info, dict):
        raise ValueError("Invalid timeout info provided.")

    base_timeout = timeout_info.get("base") or DEFAULT_TIMEOUT
    read_timeout = timeout_info.get("read") or base_timeout
    write_timeout = timeout_info.get("write") or base_timeout
    return httpx.Timeout(base_timeout, read=read_timeout, write=write_timeout)


def make_request(
    client: httpx.Client,
    base_url: str,
    auth_header: str,
    token: str,
    method: str,
    json_data: Optional[Dict[str, Any]],
    url_params: Optional[Dict[str, Any]],
    path: Optional[str],
    caller_name: Optional[str],
    caller_version: Optional[str],
    timeout: Optional[Union[httpx.Timeout, float]],
) -> httpx.Response:
    """
    Make an HTTP request to a specified URL.

    Args:
        client (httpx.Client): The httpx client for the request.
        base_url (str): The base URL for the request.
        auth_header (str): The authentication header key.
        token (str): The token used for authentication.
        method (str): The HTTP method to use for the request.
        json_data (dict, optional): JSON payload to be sent with the request.
            When None, the request is sent without a body.
        url_params (dict, optional): URL parameters to be sent with the request.
        path (str, optional): The specific path to append to the base URL.
            When None, the base URL is used as is.
        caller_name (str, optional): Name of the caller, for the User-Agent header.
        caller_version (str, optional): Version of the caller, for the
            User-Agent header.
        timeout (httpx.Timeout or float, optional): Timeout for the request;
            DEFAULT_TIMEOUT is used when this is None (or otherwise falsy).

    Returns:
        httpx.Response: The response from the HTTP request.

    Raises:
        ValueError: If the payload contains NaN or infinite floats
            (`allow_nan=False` is passed to `json.dumps`).
    """
    # Build the request headers from the token and user agent
    request_headers = {
        auth_header: token,
        "User-Agent": compose_user_agent(caller_name, caller_version),
    }

    # `path` is optional: without the fallback, a None path would put the
    # literal text "None" at the end of the URL.
    url = f"{base_url}{path or ''}"

    # Log the parameters of the request accordingly
    log_request(method, url, url_params, request_headers, json_data)

    # Serialize only an actual payload: dumping None would send the literal
    # body `null` along with requests that are meant to have no body.
    encoded_payload = (
        json.dumps(json_data, allow_nan=False, separators=(",", ":")).encode()
        if json_data is not None
        else None
    )

    # Make the request
    r = client.request(
        method=method,
        url=url,
        params=url_params,
        content=encoded_payload,
        timeout=timeout or DEFAULT_TIMEOUT,
        headers=request_headers,
    )

    # Log the response before returning it
    log_response(r)

    return r


async def amake_request(
    client: httpx.AsyncClient,
    base_url: str,
    auth_header: str,
    token: str,
    method: str,
    path: Optional[str],
    json_data: Optional[Dict[str, Any]],
    url_params: Optional[Dict[str, Any]],
    caller_name: Optional[str],
    caller_version: Optional[str],
    timeout: Optional[Union[httpx.Timeout, float]],
) -> httpx.Response:
    """
    Make an HTTP request to a specified URL, awaiting an async httpx client.

    Args:
        client (httpx.AsyncClient): The async httpx client for the request.
        base_url (str): The base URL for the request.
        auth_header (str): The authentication header key.
        token (str): The token used for authentication.
        method (str): The HTTP method to use for the request.
        path (str, optional): The specific path to append to the base URL.
            When None, the base URL is used as is.
        json_data (dict, optional): JSON payload to be sent with the request.
            When None, the request is sent without a body.
        url_params (dict, optional): URL parameters to be sent with the request.
        caller_name (str, optional): Name of the caller, for the User-Agent header.
        caller_version (str, optional): Version of the caller, for the
            User-Agent header.
        timeout (httpx.Timeout or float, optional): Timeout for the request;
            DEFAULT_TIMEOUT is used when this is None (or otherwise falsy).

    Returns:
        httpx.Response: The response from the HTTP request.

    Raises:
        ValueError: If the payload contains NaN or infinite floats
            (`allow_nan=False` is passed to `json.dumps`).
    """
    # Build the request headers from the token and user agent
    request_headers = {
        auth_header: token,
        "User-Agent": compose_user_agent(caller_name, caller_version),
    }

    # `path` is optional: without the fallback, a None path would put the
    # literal text "None" at the end of the URL.
    url = f"{base_url}{path or ''}"

    # Log the parameters of the request accordingly
    log_request(method, url, url_params, request_headers, json_data)

    # Serialize only an actual payload: dumping None would send the literal
    # body `null` along with requests that are meant to have no body.
    encoded_payload = (
        json.dumps(json_data, allow_nan=False, separators=(",", ":")).encode()
        if json_data is not None
        else None
    )

    # Make the request
    r = await client.request(
        method=method,
        url=url,
        params=url_params,
        content=encoded_payload,
        timeout=timeout or DEFAULT_TIMEOUT,
        headers=request_headers,
    )

    # Log the response before returning it
    log_response(r)

    return r


def make_payload(top_level: str, **kwargs: Any) -> Dict[str, Any]:
    """
    Build a JSON payload made of one top-level key wrapping the given fields.

    Args:
        top_level (str): The single key at the top of the payload.
        **kwargs: Fields to nest under `top_level`. Those whose value
            is None are left out; any other value (including falsy ones
            such as 0, False or "") is kept.

    Returns:
        dict: A payload of the form `{top_level: {field: value, ...}}`.
    """
    provided_fields = {
        field_name: field_value
        for field_name, field_value in kwargs.items()
        if field_value is not None
    }
    return {top_level: provided_fields}


def convert_vector_to_floats(vector: Iterable[Any]) -> List[float]:
    """
    Turn every item of a vector into a float.

    Args:
        vector (iterable): Items accepted by `float()`, e.g. numbers
            or numeric strings.

    Returns:
        list: A new list holding the float value of each item, in order.
    """
    return list(map(float, vector))


def is_list_of_floats(vector: Iterable[Any]) -> bool:
    """
    Cheaply check whether `vector` looks like a list of numbers.

    Only the first item is inspected: the remaining items of a list are
    assumed to be of the same kind as the first. An empty list passes the check.

    Returns:
        bool: True for an empty list or a list starting with a float or an int,
        False for anything else (non-list inputs included).
    """
    if not isinstance(vector, list):
        return False
    if not vector:
        return True
    return isinstance(vector[0], (float, int))


def convert_to_ejson_date_object(
    date_value: Union[datetime.date, datetime.datetime]
) -> Dict[str, int]:
    """
    Express a date or datetime as an EJSON-style `{"$date": <milliseconds>}` dict.

    Args:
        date_value (datetime.date or datetime.datetime): The value to convert.
            Timezone-aware datetimes are converted according to their own
            tzinfo; naive datetimes and plain dates are interpreted in the
            system's local time (a date standing for its local midnight).

    Returns:
        dict: A one-key dict mapping "$date" to the integer number of
        milliseconds since the epoch (sub-millisecond digits are dropped).
    """
    # datetime must be tested first, since it is a subclass of date.
    if isinstance(date_value, datetime.datetime):
        # `timestamp()` honours tzinfo, unlike mktime(timetuple()) which would
        # read an aware datetime's fields as if they were local time.
        # The seconds are taken on the microsecond-free value and milliseconds
        # added with integer arithmetic, so no float rounding can shift the result.
        whole_seconds = int(date_value.replace(microsecond=0).timestamp())
        return {"$date": whole_seconds * 1000 + date_value.microsecond // 1000}
    return {"$date": int(time.mktime(date_value.timetuple()) * 1000)}


def convert_to_ejson_uuid_object(uuid_value: UUID) -> Dict[str, str]:
    """Wrap the string form of a UUID into an EJSON-style `{"$uuid": ...}` dict."""
    uuid_text = str(uuid_value)
    return {"$uuid": uuid_text}


def convert_to_ejson_objectid_object(objectid_value: ObjectId) -> Dict[str, str]:
    """Wrap the string form of an ObjectId into an EJSON-style `{"$objectId": ...}` dict."""
    objectid_text = str(objectid_value)
    return {"$objectId": objectid_text}


def convert_ejson_date_object_to_datetime(
    date_object: Dict[str, int]
) -> datetime.datetime:
    """
    Turn an EJSON-style `{"$date": <milliseconds>}` dict into a datetime.

    Args:
        date_object (dict): A dict whose "$date" value is a timestamp
            in milliseconds.

    Returns:
        datetime.datetime: The datetime as given by `datetime.fromtimestamp`
        with no timezone argument, i.e. a naive datetime in local time.
    """
    timestamp_seconds = date_object["$date"] / 1000.0
    return datetime.datetime.fromtimestamp(timestamp_seconds)


def convert_ejson_uuid_object_to_uuid(uuid_object: Dict[str, str]) -> UUID:
    """Build a UUID from the "$uuid" string of an EJSON-style `{"$uuid": ...}` dict."""
    uuid_text = uuid_object["$uuid"]
    return UUID(uuid_text)


def convert_ejson_objectid_object_to_objectid(
    objectid_object: Dict[str, str]
) -> ObjectId:
    """Build an ObjectId from the "$objectId" string of an EJSON-style dict."""
    objectid_text = objectid_object["$objectId"]
    return ObjectId(objectid_text)


def _normalize_payload_value(path: List[str], value: Any) -> Any:
    """
    Recursively make a payload value ready to be sent to the API.

    Args:
        path (list of str): The keys leading from the payload root to `value`
            (list items contribute an empty string). It is used to recognize
            values that need special handling.
        value: The value found at `path`.

    Returns:
        The normalized value: "$vector" entries are turned into lists of
        floats, dicts and lists are rebuilt with normalized contents, and
        date/datetime, UUID and ObjectId values are replaced by their EJSON
        dict form. Anything else is returned unchanged.
    """
    last_key = ".".join(path[-1:])
    last_two_keys = ".".join(path[-2:])

    # A "$vector" directly under "projection" is not an actual vector:
    # it takes the generic treatment further down.
    if last_key == "$vector" and last_two_keys != "projection.$vector":
        if is_list_of_floats(value):
            return value
        return convert_vector_to_floats(value)

    # Containers: rebuild them, extending the path for each nested value.
    if isinstance(value, dict):
        return {
            key: _normalize_payload_value(path + [key], item)
            for key, item in value.items()
        }
    if isinstance(value, list):
        return [_normalize_payload_value(path + [""], item) for item in value]

    # Scalars having a dedicated EJSON representation.
    if isinstance(value, (datetime.datetime, datetime.date)):
        return convert_to_ejson_date_object(value)
    if isinstance(value, UUID):
        return convert_to_ejson_uuid_object(value)
    if isinstance(value, ObjectId):
        return convert_to_ejson_objectid_object(value)

    return value


def normalize_for_api(
    payload: Union[Dict[str, Any], None]
) -> Union[Dict[str, Any], None]:
    """
    Prepare a payload for being sent in an API call.

    Among other things, this makes sure that the values found under
    a "$vector" key end up as plain lists of floats.

    Args:
        payload (Dict[str, Any]): A dict expressing a payload for an API call.

    Returns:
        Dict[str, Any]: The "normalized" payload dict. A None or empty
        payload is handed back as it is.
    """
    # Nothing to normalize in a missing or empty payload.
    if not payload:
        return payload

    normalized = _normalize_payload_value([], payload)
    return cast(Dict[str, Any], normalized)


def _restore_response_value(path: List[str], value: Any) -> Any:
    """
    Recursively turn EJSON-style wrappers in an API response into Python objects.

    Args:
        path (list of str): The keys leading from the response root to `value`
            (list items contribute an empty string). It is only extended and
            handed down to the recursive calls.
        value: The value found at `path`.

    Returns:
        The restored value: one-key dicts holding "$date", "$uuid" or
        "$objectId" become datetime, UUID and ObjectId objects respectively;
        other dicts and lists are rebuilt with restored contents;
        anything else is returned unchanged.
    """
    if isinstance(value, dict):
        # Only a dict consisting of exactly one special key is a wrapper.
        if len(value) == 1:
            if "$date" in value:
                # this is `{"$date": 123456}`, restore to datetime.datetime
                return convert_ejson_date_object_to_datetime(value)
            if "$uuid" in value:
                # this is `{"$uuid": "abc123..."}`, restore to UUID
                return convert_ejson_uuid_object_to_uuid(value)
            if "$objectId" in value:
                # this is `{"$objectId": "123abc..."}`, restore to ObjectId
                return convert_ejson_objectid_object_to_objectid(value)
        return {
            key: _restore_response_value(path + [key], item)
            for key, item in value.items()
        }

    if isinstance(value, list):
        return [_restore_response_value(path + [""], item) for item in value]

    return value


def restore_from_api(response: API_RESPONSE) -> API_RESPONSE:
    """
    Post-process a dictionary as it has just come back from the API.

    Here is where, for instance, a `{"$date": 123}` item is
    turned back into a datetime object.
    """
    restored = _restore_response_value([], response)
    return cast(API_RESPONSE, restored)