|
import asyncio |
|
import dataclasses as dc |
|
import json |
|
from collections import defaultdict |
|
from collections.abc import Awaitable, Callable |
|
from typing import Any, Literal |
|
|
|
import httpx |
|
import httpx_sse |
|
|
|
# Allowed values for EditorAPIContext.priority; call_skill sends the chosen
# value as the default "priority" field of each skill request (caller-supplied
# params may override it).
Priority = Literal["low", "standard", "high"]
|
|
|
|
|
def _new_future() -> asyncio.Future[Any]:
    """Return a fresh, unresolved future owned by the running event loop.

    Serves as the ``defaultdict`` factory for ``EditorAPIContext.sse_futures``;
    within this file it is only triggered from async methods. Outside a
    running loop ``asyncio.get_running_loop`` raises ``RuntimeError``.
    """
    running_loop = asyncio.get_running_loop()
    pending: asyncio.Future[Any] = running_loop.create_future()
    return pending
|
|
|
|
|
@dc.dataclass(kw_only=True) |
|
class EditorAPIContext: |
|
uri: str |
|
user: str |
|
password: str |
|
priority: Priority = "standard" |
|
token: str | None = None |
|
verify: bool | str = True |
|
_client: httpx.AsyncClient | None = None |
|
|
|
sse_futures: dict[str, asyncio.Future[dict[str, Any]]] = dc.field(default_factory=lambda: defaultdict(_new_future)) |
|
|
|
async def __aenter__(self) -> httpx.AsyncClient: |
|
if self._client: |
|
return self._client |
|
self._client = httpx.AsyncClient(verify=self.verify) |
|
return self._client |
|
|
|
async def __aexit__(self, *args: Any) -> None: |
|
if self._client: |
|
await self._client.__aexit__(*args) |
|
self._client = None |
|
|
|
@property |
|
def auth_headers(self) -> dict[str, str]: |
|
assert self.token |
|
return {"Authorization": f"Bearer {self.token}"} |
|
|
|
async def login(self) -> None: |
|
async with self as client: |
|
response = await client.post( |
|
f"{self.uri}/auth/login", |
|
json={"username": self.user, "password": self.password}, |
|
) |
|
response.raise_for_status() |
|
self.token = response.json()["token"] |
|
|
|
async def sse_loop(self) -> None: |
|
async with self as client: |
|
response = await client.post(f"{self.uri}/sub-auth", headers=self.auth_headers) |
|
response.raise_for_status() |
|
sub_token = response.json()["token"] |
|
url = f"{self.uri}/sub/{sub_token}" |
|
async with ( |
|
httpx.AsyncClient(timeout=None, verify=self.verify) as c, |
|
httpx_sse.aconnect_sse(c, "GET", url) as es, |
|
): |
|
future = self.sse_futures["_sse_loop"] |
|
future.set_result({"status": "ok"}) |
|
async for sse in es.aiter_sse(): |
|
jdata = json.loads(sse.data) |
|
future = self.sse_futures[jdata["state"]] |
|
future.set_result(jdata) |
|
|
|
async def sse_await(self, state_id: str, timeout: float = 60.0) -> None: |
|
future = self.sse_futures[state_id] |
|
jdata = await asyncio.wait_for(future, timeout=timeout) |
|
if jdata["status"] != "ok": |
|
print("ERROR", jdata) |
|
assert jdata["status"] == "ok" |
|
del self.sse_futures[state_id] |
|
|
|
async def get_meta(self, state_id: str) -> dict[str, Any]: |
|
async with self as client: |
|
response = await client.get( |
|
f"{self.uri}/state/meta/{state_id}", |
|
headers=self.auth_headers, |
|
) |
|
response.raise_for_status() |
|
return response.json() |
|
|
|
async def run_one[Tin, Tout]( |
|
self, |
|
co: Callable[["EditorAPIContext", Tin], Awaitable[Tout]], |
|
params: Tin, |
|
) -> Tout: |
|
await self.login() |
|
async with asyncio.TaskGroup() as tg: |
|
sse_task = tg.create_task(self.sse_loop()) |
|
|
|
async def outer_co(params: Tin) -> Tout: |
|
|
|
await self.sse_await("_sse_loop") |
|
r = await co(self, params) |
|
sse_task.cancel() |
|
return r |
|
|
|
r = tg.create_task(outer_co(params)) |
|
|
|
return r.result() |
|
|
|
def run_one_sync[Tin, Tout]( |
|
self, |
|
co: Callable[["EditorAPIContext", Tin], Awaitable[Tout]], |
|
params: Tin, |
|
) -> Tout: |
|
try: |
|
loop = asyncio.get_event_loop() |
|
except RuntimeError: |
|
loop = asyncio.new_event_loop() |
|
asyncio.set_event_loop(loop) |
|
|
|
return loop.run_until_complete(self.run_one(co, params)) |
|
|
|
async def call_skill(self, uri: str, params: dict[str, Any] | None) -> str: |
|
params = {"priority": self.priority} | (params or {}) |
|
async with self as client: |
|
response = await client.post( |
|
f"{self.uri}/skills/{uri}", |
|
json=params, |
|
headers=self.auth_headers, |
|
) |
|
response.raise_for_status() |
|
state_id = response.json()["state"] |
|
await self.sse_await(state_id) |
|
return state_id |
|
|