import asyncio import dataclasses as dc import json from collections import defaultdict from collections.abc import Awaitable, Callable from typing import Any, Literal import httpx import httpx_sse Priority = Literal["low", "standard", "high"] def _new_future() -> asyncio.Future[Any]: return asyncio.get_running_loop().create_future() @dc.dataclass(kw_only=True) class EditorAPIContext: uri: str user: str password: str priority: Priority = "standard" token: str | None = None verify: bool | str = True _client: httpx.AsyncClient | None = None sse_futures: dict[str, asyncio.Future[dict[str, Any]]] = dc.field(default_factory=lambda: defaultdict(_new_future)) async def __aenter__(self) -> httpx.AsyncClient: if self._client: return self._client self._client = httpx.AsyncClient(verify=self.verify) return self._client async def __aexit__(self, *args: Any) -> None: if self._client: await self._client.__aexit__(*args) self._client = None @property def auth_headers(self) -> dict[str, str]: assert self.token return {"Authorization": f"Bearer {self.token}"} async def login(self) -> None: async with self as client: response = await client.post( f"{self.uri}/auth/login", json={"username": self.user, "password": self.password}, ) response.raise_for_status() self.token = response.json()["token"] async def sse_loop(self) -> None: async with self as client: response = await client.post(f"{self.uri}/sub-auth", headers=self.auth_headers) response.raise_for_status() sub_token = response.json()["token"] url = f"{self.uri}/sub/{sub_token}" async with ( httpx.AsyncClient(timeout=None, verify=self.verify) as c, httpx_sse.aconnect_sse(c, "GET", url) as es, ): future = self.sse_futures["_sse_loop"] future.set_result({"status": "ok"}) async for sse in es.aiter_sse(): jdata = json.loads(sse.data) future = self.sse_futures[jdata["state"]] future.set_result(jdata) async def sse_await(self, state_id: str, timeout: float = 60.0) -> None: future = self.sse_futures[state_id] jdata = await asyncio.wait_for(future, timeout=timeout) if jdata["status"] != "ok": print("ERROR", jdata) assert jdata["status"] == "ok" del self.sse_futures[state_id] async def get_meta(self, state_id: str) -> dict[str, Any]: async with self as client: response = await client.get( f"{self.uri}/state/meta/{state_id}", headers=self.auth_headers, ) response.raise_for_status() return response.json() async def run_one[Tin, Tout]( self, co: Callable[["EditorAPIContext", Tin], Awaitable[Tout]], params: Tin, ) -> Tout: await self.login() async with asyncio.TaskGroup() as tg: sse_task = tg.create_task(self.sse_loop()) async def outer_co(params: Tin) -> Tout: # _sse_loop is a fake event to wait until the SSE loop is properly setup. await self.sse_await("_sse_loop") r = await co(self, params) sse_task.cancel() return r r = tg.create_task(outer_co(params)) return r.result() def run_one_sync[Tin, Tout]( self, co: Callable[["EditorAPIContext", Tin], Awaitable[Tout]], params: Tin, ) -> Tout: try: loop = asyncio.get_event_loop() except RuntimeError: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) return loop.run_until_complete(self.run_one(co, params)) async def call_skill(self, uri: str, params: dict[str, Any] | None) -> str: params = {"priority": self.priority} | (params or {}) async with self as client: response = await client.post( f"{self.uri}/skills/{uri}", json=params, headers=self.auth_headers, ) response.raise_for_status() state_id = response.json()["state"] await self.sse_await(state_id) return state_id