File size: 3,823 Bytes
e619418 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import asyncio
import dataclasses as dc
import json
from collections import defaultdict
from collections.abc import Awaitable, Callable
from typing import Any
import httpx
import httpx_sse
def _new_future() -> asyncio.Future[Any]:
    """Create a pending future attached to the currently running event loop.

    Raises:
        RuntimeError: when called without a running event loop.
    """
    running_loop = asyncio.get_running_loop()
    return running_loop.create_future()
@dc.dataclass(kw_only=True)
class EditorAPIContext:
uri: str
user: str
password: str
token: str | None = None
verify: bool | str = True
_client: httpx.AsyncClient | None = None
sse_futures: dict[str, asyncio.Future[dict[str, Any]]] = dc.field(default_factory=lambda: defaultdict(_new_future))
async def __aenter__(self) -> httpx.AsyncClient:
if self._client:
return self._client
self._client = httpx.AsyncClient(verify=self.verify)
return self._client
async def __aexit__(self, *args: Any) -> None:
if self._client:
await self._client.__aexit__(*args)
self._client = None
@property
def auth_headers(self) -> dict[str, str]:
assert self.token
return {"Authorization": f"Bearer {self.token}"}
async def login(self) -> None:
async with self as client:
response = await client.post(
f"{self.uri}/auth/login",
json={"username": self.user, "password": self.password},
)
response.raise_for_status()
self.token = response.json()["token"]
async def sse_loop(self) -> None:
async with self as client:
response = await client.post(f"{self.uri}/sub-auth", headers=self.auth_headers)
response.raise_for_status()
sub_token = response.json()["token"]
url = f"{self.uri}/sub/{sub_token}"
async with (
httpx.AsyncClient(timeout=None, verify=self.verify) as c,
httpx_sse.aconnect_sse(c, "GET", url) as es,
):
future = self.sse_futures["_sse_loop"]
future.set_result({"status": "ok"})
async for sse in es.aiter_sse():
jdata = json.loads(sse.data)
future = self.sse_futures[jdata["state"]]
future.set_result(jdata)
async def sse_await(self, state_id: str) -> None:
future = self.sse_futures[state_id]
jdata = await future
if jdata["status"] != "ok":
print("ERROR", jdata)
assert jdata["status"] == "ok"
del self.sse_futures[state_id]
async def get_meta(self, state_id: str) -> dict[str, Any]:
async with self as client:
response = await client.get(
f"{self.uri}/state/meta/{state_id}",
headers=self.auth_headers,
)
response.raise_for_status()
return response.json()
async def run_one[Tin, Tout](
self,
co: Callable[["EditorAPIContext", Tin], Awaitable[Tout]],
params: Tin,
) -> Tout:
await self.login()
async with asyncio.TaskGroup() as tg:
sse_task = tg.create_task(self.sse_loop())
async def outer_co(params: Tin) -> Tout:
# _sse_loop is a fake event to wait until the SSE loop is properly setup.
await self.sse_await("_sse_loop")
r = await co(self, params)
sse_task.cancel()
return r
r = tg.create_task(outer_co(params))
return r.result()
def run_one_sync[Tin, Tout](
self,
co: Callable[["EditorAPIContext", Tin], Awaitable[Tout]],
params: Tin,
) -> Tout:
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(self.run_one(co, params))
|