lynxkite / server /crdt.py
darabos's picture
Some nodes are now visible.
a8859a7
raw
history blame
4.79 kB
"""CRDT is used to synchronize workspace state between the backend and the frontend(s)."""
import asyncio
import contextlib
import enum
import fastapi
import os.path
import pycrdt
import pycrdt_websocket
import pycrdt_websocket.ystore
# FastAPI router that carries this module's CRDT websocket endpoint.
router = fastapi.APIRouter()
def ws_exception_handler(exception, log):
    """Report an exception raised inside the websocket server.

    Shaped to be passed as ``exception_handler`` to
    ``pycrdt_websocket.WebsocketServer``. The exception is printed to stdout
    and then passed to ``log.exception``.

    Returns:
        Always ``True``, which pycrdt_websocket treats as "exception handled".
    """
    print("exception", exception)
    log.exception(exception)
    handled = True
    return handled
class WebsocketServer(pycrdt_websocket.WebsocketServer):
    """pycrdt_websocket server whose rooms each hold one LynxKite workspace.

    A room's document contains a single ``workspace`` map. Rooms are created
    lazily on first use and persisted through a ``FileYStore`` in
    ``crdt_data/``.
    """

    async def init_room(self, name: str) -> pycrdt_websocket.YRoom:
        """Build the room for workspace ``name``.

        The document is restored by replaying the updates saved in the file
        store. The ``nodes``/``edges``/``env`` entries are created when missing.
        A document without ``env`` is seeded from the workspace file at
        ``data/{name}`` (if one exists). Every later deep change of the
        workspace map schedules ``workspace_changed`` as an asyncio task.

        Returns:
            A new (not yet started) ``YRoom``; the map is exposed as ``room.ws``.
        """
        ystore = pycrdt_websocket.ystore.FileYStore(f"crdt_data/{name}.crdt")
        ydoc = pycrdt.Doc()
        ydoc["workspace"] = ws = pycrdt.Map()
        # Replay updates from the store. Read them all first, then apply;
        # item[0] is the update payload, the trailing metadata is not needed.
        try:
            updates = [item[0] async for item in ystore.read()]
            for update in updates:
                ydoc.apply_update(update)
        except pycrdt_websocket.ystore.YDocNotFound:
            pass
        if "nodes" not in ws:
            ws["nodes"] = pycrdt.Array()
        if "edges" not in ws:
            ws["edges"] = pycrdt.Array()
        if "env" not in ws:
            ws["env"] = "unset"
            # The store had no saved workspace state (presumably a new
            # document), so fall back to the workspace file on disk, if any.
            try_to_load_workspace(ws, name)
        room = pycrdt_websocket.YRoom(ystore=ystore, ydoc=ydoc)
        room.ws = ws
        # The event loop keeps only weak references to tasks, so hold a strong
        # reference until each one finishes; otherwise a pending execution
        # could be garbage-collected mid-flight.
        pending_tasks = set()

        def on_change(changes):
            task = asyncio.create_task(workspace_changed(changes, ws))
            pending_tasks.add(task)
            task.add_done_callback(pending_tasks.discard)

        ws.observe_deep(on_change)
        return room

    async def get_room(self, name: str) -> pycrdt_websocket.YRoom:
        """Return the started room for ``name``, creating it on first use."""
        if name not in self.rooms:
            new_room = await self.init_room(name)
            # init_room awaits, so a concurrent connection for the same name may
            # have registered a room meanwhile. Keep the first one so that all
            # clients of a workspace share a single room.
            self.rooms.setdefault(name, new_room)
        room = self.rooms[name]
        await self.start_room(room)
        return room
# Process-wide server managing every CRDT room.
# auto_clean_rooms=False keeps a room around after its last client disconnects.
websocket_server = WebsocketServer(
    # Uncomment to route server exceptions through ws_exception_handler.
    # exception_handler=ws_exception_handler,
    auto_clean_rooms=False,
)
# ASGI adapter that the /ws/crdt endpoint hands websocket connections to.
asgi_server = pycrdt_websocket.ASGIServer(websocket_server)
# Last cleaned workspace that was executed; lets workspace_changed skip reruns
# when a change does not affect the cleaned input.
# NOTE(review): a single value shared by all rooms, not tracked per workspace.
last_ws_input = None
def clean_input(ws_pyd):
    """Strip presentation-only state from a workspace model, in place.

    For every node: the ``display`` result is cleared, the position is reset
    to (0, 0), and any extra (undeclared) attributes listed in
    ``model_extra`` are deleted. Two workspaces that differ only in these
    fields then compare equal.
    """
    for node in ws_pyd.nodes:
        node.data.display = None
        node.position.x = node.position.y = 0
        # Snapshot the names first: deleting an extra attribute may shrink
        # model_extra while we would otherwise be iterating over it.
        extra_names = list(node.model_extra or ())
        for extra_name in extra_names:
            delattr(node, extra_name)
def _empty_crdt_container(value):
    """Return an empty ``pycrdt.Map`` for a dict value, else a ``pycrdt.Array``."""
    return pycrdt.Map() if isinstance(value, dict) else pycrdt.Array()


def _crdt_scalar(value):
    """Convert a leaf value for CRDT storage: enum members become ``str(member)``."""
    return str(value) if isinstance(value, enum.Enum) else value


def crdt_update(crdt_obj, python_obj, boxes=frozenset()):
    """Recursively copy a plain Python structure into a CRDT container.

    Dicts are merged key by key into ``crdt_obj``. Lists are written
    element-wise into ``crdt_obj``, overwriting existing positions and
    appending past its current end. Nested dicts/lists become new
    ``pycrdt.Map``/``pycrdt.Array`` objects when the target slot is missing
    (or ``None`` in a map), and are then filled recursively. Enum members
    are stored as ``str(member)`` in both dicts and lists.

    This is a merge, not a replace: keys or trailing elements already in
    ``crdt_obj`` but absent from ``python_obj`` are left untouched.

    Args:
        crdt_obj: Target map (``pycrdt.Map`` or dict-like) or array
            (``pycrdt.Array`` or list-like), updated in place.
        python_obj: Source ``dict`` or ``list``.
        boxes: Dict keys whose values are stored as-is ("boxed") instead of
            being expanded into nested CRDT types.

    Raises:
        ValueError: If ``python_obj`` is neither a dict nor a list.
    """
    if isinstance(python_obj, dict):
        for key, value in python_obj.items():
            if key in boxes:
                crdt_obj[key] = value
            elif isinstance(value, (dict, list)):
                if crdt_obj.get(key) is None:
                    crdt_obj[key] = _empty_crdt_container(value)
                crdt_update(crdt_obj[key], value, boxes)
            else:
                crdt_obj[key] = _crdt_scalar(value)
    elif isinstance(python_obj, list):
        for i, value in enumerate(python_obj):
            if isinstance(value, (dict, list)):
                if i >= len(crdt_obj):
                    crdt_obj.append(_empty_crdt_container(value))
                crdt_update(crdt_obj[i], value, boxes)
            elif i >= len(crdt_obj):
                crdt_obj.append(_crdt_scalar(value))
            else:
                crdt_obj[i] = _crdt_scalar(value)
    else:
        raise ValueError(f"Invalid type: {python_obj!r}")
def try_to_load_workspace(ws, name):
    """Seed the CRDT workspace map ``ws`` from the file ``data/{name}``.

    Does nothing when that path does not exist. Otherwise the file is loaded
    with ``workspace.load`` and merged into ``ws`` via ``crdt_update``. Any
    ``display`` entry is stored as an opaque value rather than expanded into
    nested CRDT types.
    """
    from . import workspace

    json_path = f"data/{name}"
    if not os.path.exists(json_path):
        return
    loaded = workspace.load(json_path)
    crdt_update(ws, loaded.model_dump(), boxes={"display"})
async def workspace_changed(e, ws_crdt):
    """Execute the workspace after a CRDT change and write node results back.

    ``init_room`` schedules this as a task for every deep change of a room's
    ``workspace`` map. The CRDT state is validated into a
    ``workspace.Workspace`` and normalized with ``clean_input``. If the result
    equals the last executed input, nothing is done. Otherwise the workspace
    is executed and each node's ``display`` and ``error`` are written back
    into the CRDT map.

    Args:
        e: Change events from the CRDT observer (not used).
        ws_crdt: The room's ``workspace`` ``pycrdt.Map``.
    """
    global last_ws_input
    # Imported locally, possibly to avoid an import cycle (unverified).
    from . import workspace
    ws_pyd = workspace.Workspace.model_validate(ws_crdt.to_py())
    # Drop display/position/extra fields so that changes to them (including
    # the display values written back below) do not count as a new input.
    clean_input(ws_pyd)
    # NOTE(review): last_ws_input is one module-level value shared by all
    # rooms. Two rooms whose cleaned workspaces are equal would suppress each
    # other's execution.
    if ws_pyd == last_ws_input:
        return
    # Deep copy: execute() fills in results on ws_pyd (they are read back
    # below), and the remembered input must not change along with it.
    last_ws_input = ws_pyd.model_copy(deep=True)
    await workspace.execute(ws_pyd)
    # NOTE(review): CRDT nodes and model nodes are paired by position after the
    # await. If nodes were added, removed or reordered while execute() ran,
    # results may land on the wrong node. Also, `error` is not cleared by
    # clean_input, so writing back a changed error likely triggers one more run.
    for nc, np in zip(ws_crdt["nodes"], ws_pyd.nodes):
        if "data" not in nc:
            nc["data"] = pycrdt.Map()
        # Display is added as an opaque Box.
        nc["data"]["display"] = np.data.display
        nc["data"]["error"] = np.data.error
@contextlib.asynccontextmanager
async def lifespan(app):
    """App lifespan hook that keeps the CRDT websocket server running.

    Enters ``websocket_server``'s async context before the app starts
    serving and exits it on shutdown; ``app`` itself is not used.
    """
    async with contextlib.AsyncExitStack() as stack:
        await stack.enter_async_context(websocket_server)
        yield
def sanitize_path(path):
    """Confine ``path`` to a relative path that cannot climb above its root.

    The path is anchored at "/" and normalized, so leading ".." components
    collapse against the root. The result is then made relative again: for
    example "../../etc/passwd" becomes "etc/passwd" and "/a/b" becomes "a/b".
    Inputs that reduce to the root (such as "..") yield ".".
    """
    anchored = os.path.join("/", path)
    collapsed = os.path.normpath(anchored)
    return os.path.relpath(collapsed, "/")
@router.websocket("/ws/crdt/{room_name}")
async def crdt_websocket(websocket: fastapi.WebSocket, room_name: str):
    """Attach a websocket client to the CRDT room for ``room_name``.

    The name goes through ``sanitize_path`` first because rooms turn it into
    file paths. The connection is then handed to the pycrdt_websocket ASGI
    server with a minimal scope whose "path" is the room name.
    NOTE(review): relies on Starlette's private ``_receive``/``_send``.
    """
    scope = {"path": sanitize_path(room_name)}
    await asgi_server(scope, websocket._receive, websocket._send)