File size: 4,633 Bytes
a180fd2
d43f961
a180fd2
 
 
 
 
 
 
 
 
 
 
d43f961
a180fd2
 
 
 
 
 
 
 
 
 
 
 
 
 
d43f961
 
 
 
 
 
a0194e7
d43f961
 
 
 
 
 
a180fd2
 
 
 
 
 
 
 
 
 
 
d43f961
 
 
 
 
 
 
 
 
 
a0194e7
d43f961
 
a0194e7
 
 
d43f961
 
a0194e7
d43f961
 
 
a0194e7
d43f961
 
 
 
 
 
 
a0194e7
d43f961
 
 
a0194e7
d43f961
 
 
 
 
 
 
 
a0194e7
 
 
 
 
 
 
 
d43f961
 
 
 
 
 
 
 
a0194e7
d43f961
 
 
 
 
a0194e7
a180fd2
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
'''CRDT is used to synchronize workspace state for backend and frontend(s).'''
import asyncio
import contextlib
import fastapi
import os.path
import pycrdt
import pycrdt_websocket

import pycrdt_websocket.ystore

# Router holding the CRDT websocket endpoint defined below; presumably
# included into the main FastAPI application elsewhere.
router = fastapi.APIRouter()

def ws_exception_handler(exception, log):
    """Log an exception raised inside the CRDT websocket server and mark it handled.

    Args:
        exception: The exception reported by the websocket server.
        log: Logger to report it to.

    Returns:
        Always True, i.e. the exception is reported as handled.
    """
    # Pass the exception explicitly: this handler may be invoked outside the
    # ``except`` block that caught it, in which case ``log.exception`` alone
    # would find no active exception and log without a traceback.
    log.exception(exception, exc_info=exception)
    return True

class WebsocketServer(pycrdt_websocket.WebsocketServer):
    """Websocket server whose rooms each hold one persisted workspace document.

    Every room is backed by a ``FileYStore`` at ``crdt_data/<name>.crdt`` and
    exposes the document's ``workspace`` map as ``room.ws``.
    """

    async def init_room(self, name):
        """Build the room for workspace *name*.

        Replays stored updates into a fresh document, ensures the ``nodes``,
        ``edges`` and ``env`` keys exist, and schedules ``workspace_changed``
        on every (deep) change to the workspace map.
        """
        ystore = pycrdt_websocket.ystore.FileYStore(f'crdt_data/{name}.crdt')
        ydoc = pycrdt.Doc()
        ydoc['workspace'] = ws = pycrdt.Map()
        # Replay updates from the store. Only the first element of each stored
        # item (the update payload) is needed here.
        try:
            async for item in ystore.read():
                ydoc.apply_update(item[0])
        except pycrdt_websocket.ystore.YDocNotFound:
            pass  # Nothing stored yet: start from an empty workspace.
        if 'nodes' not in ws:
            ws['nodes'] = pycrdt.Array()
        if 'edges' not in ws:
            ws['edges'] = pycrdt.Array()
        if 'env' not in ws:
            # Treat a missing 'env' as a brand-new document and try to seed it
            # from a saved workspace file.
            ws['env'] = 'unset'
            try_to_load_workspace(ws, name)
        room = pycrdt_websocket.YRoom(ystore=ystore, ydoc=ydoc)
        room.ws = ws
        # The event loop keeps only weak references to tasks, so hold each
        # change-handling task here until it completes; otherwise it could be
        # garbage-collected mid-execution.
        pending_tasks = set()

        def on_change(changes):
            task = asyncio.create_task(workspace_changed(changes, ws))
            pending_tasks.add(task)
            task.add_done_callback(pending_tasks.discard)

        ws.observe_deep(on_change)
        return room

    async def get_room(self, name: str) -> pycrdt_websocket.YRoom:
        """Return the room for *name*, creating it on first use, and ensure it is started."""
        if name not in self.rooms:
            # NOTE(review): init_room awaits, so two concurrent first requests
            # for the same name could each build a room and the later
            # assignment would replace the earlier one — confirm acceptable.
            self.rooms[name] = await self.init_room(name)
        room = self.rooms[name]
        await self.start_room(room)
        return room

# Single shared CRDT server for all rooms; auto_clean_rooms=False keeps rooms
# around instead of cleaning them up automatically.
websocket_server = WebsocketServer(
    auto_clean_rooms=False,
    exception_handler=ws_exception_handler,
)
# ASGI adapter through which the FastAPI websocket route feeds connections.
asgi_server = pycrdt_websocket.ASGIServer(websocket_server)

# Cleaned copy of the workspace most recently sent for execution (or None).
last_ws_input = None
def clean_input(ws_pyd):
    """Blank out node fields that must not count as a change to the workspace input.

    For every node in ``ws_pyd.nodes``: clears ``data.display``, resets
    ``position`` to (0, 0), and deletes each extra attribute listed in
    ``model_extra``. Mutates *ws_pyd* in place and returns None.
    """
    for node in ws_pyd.nodes:
        node.data.display = None
        node.position.x = node.position.y = 0
        # Snapshot the keys first: deleting attributes may shrink model_extra.
        for extra_key in list(node.model_extra or ()):
            delattr(node, extra_key)

def crdt_update(crdt_obj, python_obj, boxes=frozenset()):
    """Merge a plain Python dict/list into an existing CRDT container, in place.

    Dicts are merged key by key and lists element by element. A nested dict or
    list becomes a new ``pycrdt.Map``/``pycrdt.Array`` only when the target
    slot is missing (``None``) or lies past the end of the array; otherwise the
    existing container is updated recursively. Keys or items present in
    *crdt_obj* but absent from *python_obj* are left untouched.

    Args:
        crdt_obj: Target mapping (for a dict source) or sequence (for a list
            source), normally a ``pycrdt.Map`` / ``pycrdt.Array``.
        python_obj: Source dict or list.
        boxes: Dict keys whose values are stored as-is (opaque) instead of
            being converted into nested CRDT containers. Only dict entries
            consult it.

    Raises:
        ValueError: If *python_obj* is neither a dict nor a list.
    """

    def new_container(value):
        # Fresh CRDT container matching the shape of a nested dict/list value.
        return pycrdt.Map() if isinstance(value, dict) else pycrdt.Array()

    if isinstance(python_obj, dict):
        for key, value in python_obj.items():
            if key in boxes or not isinstance(value, (dict, list)):
                # Scalars and boxed values are stored directly.
                crdt_obj[key] = value
                continue
            if crdt_obj.get(key) is None:
                crdt_obj[key] = new_container(value)
            crdt_update(crdt_obj[key], value, boxes)
    elif isinstance(python_obj, list):
        for i, value in enumerate(python_obj):
            if isinstance(value, (dict, list)):
                if i >= len(crdt_obj):
                    crdt_obj.append(new_container(value))
                crdt_update(crdt_obj[i], value, boxes)
            elif i >= len(crdt_obj):
                crdt_obj.append(value)
            else:
                crdt_obj[i] = value
    else:
        raise ValueError(f'Invalid type {type(python_obj).__name__}: {python_obj!r}')


def try_to_load_workspace(ws, name):
    """Seed the CRDT workspace *ws* from the saved file ``data/<name>``, if it exists.

    The loaded workspace is merged in with ``crdt_update``; ``display`` values
    are stored as opaque values rather than nested CRDT containers. Does
    nothing when the file is absent.
    """
    from . import workspace

    json_path = f'data/{name}'
    if not os.path.exists(json_path):
        return
    loaded = workspace.load(json_path)
    crdt_update(ws, loaded.model_dump(), boxes={'display'})

async def workspace_changed(e, ws_crdt):
    """Re-execute the workspace after a CRDT change and write node results back.

    Args:
        e: The change events delivered by ``observe_deep`` (not used here).
        ws_crdt: The ``workspace`` CRDT map of the room that changed.
    """
    global last_ws_input
    from . import workspace
    ws_pyd = workspace.Workspace.model_validate(ws_crdt.to_py())
    # Strip outputs, positions and extra fields so only real input changes
    # trigger a run; the display/error writes below also fire change events
    # and must not cause another execution.
    clean_input(ws_pyd)
    # NOTE(review): last_ws_input is module-global and therefore shared by all
    # rooms; two rooms holding identical workspaces would suppress each
    # other's execution — confirm whether per-room state is intended.
    if ws_pyd == last_ws_input:
        return
    # Record this input before awaiting, so change events that arrive while
    # execute() is running are compared against it.
    last_ws_input = ws_pyd.model_copy(deep=True)
    # Presumably fills in data.display / data.error on ws_pyd's nodes, which
    # are read back below.
    await workspace.execute(ws_pyd)
    # NOTE(review): CRDT nodes are paired with results by position; if nodes
    # were added, removed or reordered during execute(), results could land on
    # the wrong node — confirm.
    for nc, np in zip(ws_crdt['nodes'], ws_pyd.nodes):
        if 'data' not in nc:
            nc['data'] = pycrdt.Map()
        # Display is added as an opaque Box.
        nc['data']['display'] = np.data.display
        nc['data']['error'] = np.data.error

@contextlib.asynccontextmanager
async def lifespan(app):
    """Keep the shared CRDT websocket server running for the application's lifetime.

    Presumably passed as the FastAPI ``lifespan`` handler; *app* is unused.
    """
    server = websocket_server
    async with server:
        yield

def sanitize_path(path):
    """Confine *path* to a relative path that cannot climb above its root.

    The path is anchored at ``/`` and normalized (collapsing ``..`` and
    ``.`` segments), then made relative to ``/`` again. For example,
    ``'../../etc/passwd'`` becomes ``'etc/passwd'`` and ``''`` becomes ``'.'``.
    """
    root = "/"
    anchored = os.path.normpath(os.path.join(root, path))
    return os.path.relpath(anchored, root)

@router.websocket("/ws/crdt/{room_name}")
async def crdt_websocket(websocket: fastapi.WebSocket, room_name: str):
    """Serve the CRDT sync protocol for one room over this websocket.

    The room name is sanitized first because it is passed on as the ASGI
    ``path`` and presumably becomes the room name used to build file paths.
    """
    scope = {'path': sanitize_path(room_name)}
    # NOTE(review): _receive/_send are private Starlette attributes; they hand
    # the raw ASGI channels to pycrdt_websocket's ASGI server.
    await asgi_server(scope, websocket._receive, websocket._send)