|
"""
Helpers for proxying a realtime websocket session between a client and a backend.

``RealTimeStreaming`` packages the following forwarding pattern and adds
message logging on top of it:

    async with websockets.connect(  # type: ignore
        url,
        extra_headers={
            "api-key": api_key,  # type: ignore
        },
    ) as backend_ws:
        forward_task = asyncio.create_task(
            forward_messages(websocket, backend_ws)
        )

        try:
            while True:
                message = await websocket.receive_text()
                await backend_ws.send(message)
        except websockets.exceptions.ConnectionClosed:  # type: ignore
            forward_task.cancel()
        finally:
            if not forward_task.done():
                forward_task.cancel()
                try:
                    await forward_task
                except asyncio.CancelledError:
                    pass
"""
|
|
|
import asyncio |
|
import concurrent.futures |
|
import json |
|
from typing import Any, Dict, List, Optional, Union |
|
|
|
import litellm |
|
|
|
from .litellm_logging import Logging as LiteLLMLogging |
|
|
|
|
|
# Module-wide worker pool (10 threads); RealTimeStreaming.log_messages
# submits work to it.
executor = concurrent.futures.ThreadPoolExecutor(
    max_workers=10,
)

# Realtime event types kept for logging when
# ``litellm.logged_real_time_event_types`` is left unset (None).
DefaultLoggedRealTimeEventTypes = [
    "session.created",  # first event of a session
    "response.create",
    "response.done",
]
|
|
|
|
|
class RealTimeStreaming:
    """Forward a realtime websocket session between a client and a backend.

    Client messages are recorded with ``store_input`` and sent to the backend.
    Backend messages are sent to the client. When their event ``type`` is
    selected by ``logged_real_time_event_types``, they are also appended to
    ``messages``. Those messages go to ``logging_obj`` once the backend stream
    ends.
    """

    # Strong references to fire-and-forget logging tasks. The event loop keeps
    # only weak references to tasks, so an unreferenced task may be garbage
    # collected before it completes. Shared by all instances on purpose.
    _background_tasks: set = set()

    def __init__(
        self,
        websocket: Any,
        backend_ws: Any,
        logging_obj: Optional[LiteLLMLogging] = None,
    ):
        """
        Args:
            websocket: Client-side websocket, used via ``receive_text`` and
                ``send_text``.
            backend_ws: Backend websocket, used via ``recv`` and ``send``.
            logging_obj: Optional LiteLLM logging object. It receives client
                input via ``pre_call`` and the stored messages via its success
                handlers.
        """
        self.websocket = websocket
        self.backend_ws = backend_ws
        self.logging_obj = logging_obj
        # Backend messages selected for logging, in arrival order.
        self.messages: List = []
        # Most recent message received from the client.
        self.input_message: Union[str, Dict] = {}

        # None means "not configured": fall back to the default event list.
        # The value "*" selects every event (see _should_store_message).
        _logged_real_time_event_types = litellm.logged_real_time_event_types
        if _logged_real_time_event_types is None:
            _logged_real_time_event_types = DefaultLoggedRealTimeEventTypes
        self.logged_real_time_event_types = _logged_real_time_event_types

    def _should_store_message(self, message: Union[str, bytes]) -> bool:
        """Return True if ``message`` should be kept for logging.

        With ``"*"`` every message is kept. Otherwise the message must be a
        JSON object whose ``type`` is in ``logged_real_time_event_types``.
        Undecodable or malformed messages are skipped rather than raising,
        because an exception here would end the backend->client relay loop.
        """
        if self.logged_real_time_event_types == "*":
            return True
        try:
            if isinstance(message, bytes):
                message = message.decode("utf-8")
            message_obj = json.loads(message)
        except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
            return False
        if not isinstance(message_obj, dict) or "type" not in message_obj:
            return False
        return message_obj["type"] in self.logged_real_time_event_types

    def store_message(self, message: Union[str, bytes]):
        """Append ``message`` to ``self.messages`` if it is selected for logging."""
        if self._should_store_message(message):
            self.messages.append(message)

    def store_input(self, message: Union[str, dict]):
        """Record the latest client message and pass it to ``pre_call``.

        Args:
            message: The raw client message. ``client_ack_messages`` passes the
                text returned by ``receive_text``.
        """
        self.input_message = message
        if self.logging_obj:
            self.logging_obj.pre_call(input=message, api_key="")

    async def log_messages(self):
        """Send the stored messages to the logging object's success handlers.

        The async handler is scheduled as a task on the running loop. The sync
        handler is submitted to the module-level thread pool so it does not
        block the event loop.
        """
        if not self.logging_obj:
            return

        task = asyncio.create_task(
            self.logging_obj.async_success_handler(self.messages)
        )
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

        # Pass the callable and its argument separately. Calling it here would
        # run it on the loop and submit its return value instead.
        executor.submit(self.logging_obj.success_handler, self.messages)

    async def backend_to_client_send_messages(self):
        """Relay backend messages to the client until the stream stops.

        Selected messages are stored as they pass through. However the loop
        ends (backend close, any other error, or task cancellation), the
        stored messages are logged in ``finally``.
        """
        import logging

        import websockets

        try:
            while True:
                message = await self.backend_ws.recv()
                await self.websocket.send_text(message)
                self.store_message(message)
        except websockets.exceptions.ConnectionClosed:
            pass  # backend closed the connection; nothing more to relay
        except Exception:
            # Best-effort relay: stop forwarding, but leave a trace instead of
            # discarding the error silently.
            logging.getLogger(__name__).debug(
                "realtime backend->client forwarding stopped", exc_info=True
            )
        finally:
            await self.log_messages()

    async def client_ack_messages(self):
        """Relay client messages to the backend, recording each as input.

        Returns when the backend connection is closed.
        """
        import websockets

        try:
            while True:
                message = await self.websocket.receive_text()
                self.store_input(message=message)
                await self.backend_ws.send(message)
        except websockets.exceptions.ConnectionClosed:
            pass

    async def bidirectional_forward(self):
        """Run both relay directions until the client->backend side finishes.

        The backend->client relay runs as a task. Once the client side returns
        or raises, that task is cancelled (if still running) and awaited, so
        its ``finally`` logging completes before this coroutine exits.
        """
        import websockets

        forward_task = asyncio.create_task(self.backend_to_client_send_messages())
        try:
            await self.client_ack_messages()
        except websockets.exceptions.ConnectionClosed:
            forward_task.cancel()
        finally:
            if not forward_task.done():
                forward_task.cancel()
                try:
                    await forward_task
                except asyncio.CancelledError:
                    pass
|