Tai Truong
fix readme
d202ada
raw
history blame
3.02 kB
from typing import Any
import socketio
from loguru import logger
from langflow.services.base import Service
from langflow.services.cache.base import AsyncBaseCacheService, CacheService
from langflow.services.deps import get_chat_service
from langflow.services.socket.utils import build_vertex, get_vertices
class SocketIOService(Service):
    """Bridge between a Socket.IO server and the flow-building helpers.

    ``init`` registers this service's coroutines as event handlers on a
    ``socketio.AsyncServer``. ``get_cache``/``set_cache`` store per-client
    build results in the injected cache service, keyed by the Socket.IO
    session id (``sid``). Both sync (``CacheService``) and async
    (``AsyncBaseCacheService``) cache backends are supported.
    """

    name = "socket_service"

    def __init__(self, cache_service: CacheService | AsyncBaseCacheService):
        """Create the service around ``cache_service``.

        The Socket.IO server is attached later via ``init``.
        """
        self.cache_service = cache_service
        # Create these up front so the instance is fully formed before
        # ``init`` runs (or when ``init`` receives a falsy server); previously
        # they only existed after handler registration, so early calls hit
        # AttributeError.
        self.sio: socketio.AsyncServer | None = None
        self.sessions: dict[str, dict] = {}

    def init(self, sio: socketio.AsyncServer) -> None:
        """Attach ``sio`` and register this service's event handlers on it.

        Handlers are only registered when ``sio`` is truthy. Registering also
        resets the table of tracked sessions.
        """
        self.sio = sio
        if self.sio:
            # ``event`` registers each handler under its function name
            # ("connect", "disconnect"); the others use explicit event names.
            self.sio.event(self.connect)
            self.sio.event(self.disconnect)
            self.sio.on("message")(self.message)
            self.sio.on("get_vertices")(self.on_get_vertices)
            self.sio.on("build_vertex")(self.on_build_vertex)
            self.sessions = {}

    async def emit_error(self, sid, error) -> None:
        """Send an ``"error"`` event carrying ``error`` to client ``sid``."""
        await self.sio.emit("error", to=sid, data=error)

    async def connect(self, sid, environ) -> None:
        """Record a newly connected client and its connection environ."""
        logger.info(f"Socket connected: {sid}")
        self.sessions[sid] = environ

    async def disconnect(self, sid) -> None:
        """Forget a disconnected client; unknown ``sid`` values are ignored."""
        logger.info(f"Socket disconnected: {sid}")
        self.sessions.pop(sid, None)

    async def message(self, sid, data=None) -> None:
        """Answer a ``"message"`` event by echoing ``data`` back to the sender.

        When ``data`` is falsy, a fixed placeholder payload is sent instead.
        """
        await self.emit_message(to=sid, data=data or {"foo": "bar", "baz": [1, 2, 3]})

    async def emit_message(self, to, data) -> None:
        """Send a ``"message"`` event with ``data`` to recipient ``to``."""
        await self.sio.emit("message", to=to, data=data)

    async def emit_token(self, to, data) -> None:
        """Send a ``"token"`` event with ``data`` to recipient ``to``."""
        await self.sio.emit("token", to=to, data=data)

    async def on_get_vertices(self, sid, flow_id) -> None:
        """Handle ``"get_vertices"`` by delegating to ``get_vertices``.

        The current chat service is passed through to the helper.
        """
        await get_vertices(self.sio, sid, flow_id, get_chat_service())

    async def on_build_vertex(self, sid, flow_id, vertex_id) -> None:
        """Handle ``"build_vertex"`` by delegating to ``build_vertex``.

        This service's cache accessors are passed in so results are stored
        per client ``sid``.
        """
        await build_vertex(
            sio=self.sio,
            sid=sid,
            flow_id=flow_id,
            vertex_id=vertex_id,
            get_cache=self.get_cache,
            set_cache=self.set_cache,
        )

    async def get_cache(self, sid: str) -> Any:
        """Return the cached entry for client ``sid``.

        The lookup is awaited when the backend is an ``AsyncBaseCacheService``.
        What a missing key yields depends on the cache backend (not shown
        here).
        """
        value = self.cache_service.get(sid)
        if isinstance(self.cache_service, AsyncBaseCacheService):
            return await value
        return value

    async def set_cache(self, sid: str, build_result: Any) -> bool:
        """Store ``build_result`` (and its type) for client ``sid``.

        Returns:
            Whether ``sid`` is present in the cache after the upsert.
        """
        # Keyed by the session id rather than the flow id, because the flow
        # id is already used as a key in this cache.
        result_dict = {
            "result": build_result,
            "type": type(build_result),
        }
        result = self.cache_service.upsert(sid, result_dict)
        if isinstance(self.cache_service, AsyncBaseCacheService):
            # Async backends return coroutines for both upsert and contains.
            await result
            return await self.cache_service.contains(sid)
        return sid in self.cache_service