# Source path: TestLLM / litellm / litellm_core_utils / realtime_streaming.py
# (Web-page extraction residue: uploaded by Raju2024, "Upload 1072 files",
# commit e3278e4, 4.55 kB.)
"""
Bidirectional websocket forwarding between a client and a realtime backend.

Sketch of the connect-and-forward pattern that ``RealTimeStreaming`` mirrors:

    async with websockets.connect(  # type: ignore
        url,
        extra_headers={
            "api-key": api_key,  # type: ignore
        },
    ) as backend_ws:
        forward_task = asyncio.create_task(
            forward_messages(websocket, backend_ws)
        )
        try:
            while True:
                message = await websocket.receive_text()
                await backend_ws.send(message)
        except websockets.exceptions.ConnectionClosed:  # type: ignore
            forward_task.cancel()
        finally:
            if not forward_task.done():
                forward_task.cancel()
                try:
                    await forward_task
                except asyncio.CancelledError:
                    pass
"""
import asyncio
import concurrent.futures
import json
import logging
from typing import Any, Dict, List, Optional, Union

import litellm

from .litellm_logging import Logging as LiteLLMLogging
# Module-wide worker pool, capped at 10 threads; RealTimeStreaming.log_messages
# hands its synchronous logging work to this pool.
executor = concurrent.futures.ThreadPoolExecutor(
    max_workers=10,
)

# Event types retained for logging when ``litellm.logged_real_time_event_types``
# is left as ``None``.
DefaultLoggedRealTimeEventTypes = [
    "session.created",
    "response.create",
    "response.done",
]
class RealTimeStreaming:
    """Relay realtime websocket traffic between a client and a backend.

    Messages received from the backend are forwarded to the client; those
    whose ``type`` matches the configured event types are kept in
    ``self.messages`` and handed to the logging object once the backend
    stream ends. Messages received from the client are recorded as the call
    input and forwarded to the backend.
    """

    def __init__(
        self,
        websocket: Any,
        backend_ws: Any,
        logging_obj: "Optional[LiteLLMLogging]" = None,
    ):
        """Bind the two websocket endpoints and the optional logging object.

        Args:
            websocket: Client-side connection; must provide async
                ``receive_text()`` and ``send_text()``.
            backend_ws: Backend connection; must provide async ``recv()``
                and ``send()``.
            logging_obj: Optional logging object. When ``None`` nothing is
                logged.
        """
        self.websocket = websocket
        self.backend_ws = backend_ws
        self.logging_obj = logging_obj
        self.messages: List = []
        self.input_message: Union[str, Dict] = {}
        # Strong reference to the async logging task so it cannot be
        # garbage-collected before it finishes.
        self._logging_task: Optional[asyncio.Task] = None

        _logged_real_time_event_types = litellm.logged_real_time_event_types
        if _logged_real_time_event_types is None:
            _logged_real_time_event_types = DefaultLoggedRealTimeEventTypes
        self.logged_real_time_event_types = _logged_real_time_event_types

    def _should_store_message(self, message: Union[str, bytes]) -> bool:
        """Return True when ``message`` is of an event type that is logged.

        ``"*"`` as the configured event types matches every message. Payloads
        that are not valid UTF-8 JSON objects carrying a string ``type`` are
        never stored (unless ``"*"`` is configured) instead of raising, so a
        single unexpected frame cannot abort the forwarding loop.
        """
        if self.logged_real_time_event_types == "*":
            return True
        try:
            if isinstance(message, bytes):
                message = message.decode("utf-8")
            message_obj = json.loads(message)
        except ValueError:
            # Covers both UnicodeDecodeError and json.JSONDecodeError.
            return False
        if not isinstance(message_obj, dict):
            return False
        _msg_type = message_obj.get("type")
        if not isinstance(_msg_type, str):
            return False
        return _msg_type in self.logged_real_time_event_types

    def store_message(self, message: Union[str, bytes]):
        """Append ``message`` to ``self.messages`` if its type is logged."""
        if self._should_store_message(message):
            self.messages.append(message)

    def store_input(self, message: Union[str, dict]):
        """Record the latest client message and report it as the call input."""
        self.input_message = message
        if self.logging_obj:
            self.logging_obj.pre_call(input=message, api_key="")

    async def log_messages(self):
        """Hand the stored messages to the async and sync success handlers.

        Does nothing when no logging object was supplied.
        """
        if not self.logging_obj:
            return
        ## ASYNC LOGGING
        self._logging_task = asyncio.create_task(
            self.logging_obj.async_success_handler(self.messages)
        )
        ## SYNC LOGGING
        # Pass the callable and its argument separately so the handler runs
        # on a worker thread rather than inline on the event loop.
        executor.submit(self.logging_obj.success_handler, self.messages)

    async def backend_to_client_send_messages(self):
        """Forward backend messages to the client until the stream ends.

        Forwarding is best-effort: a closed backend connection or any other
        error ends the loop without raising. Logging is always triggered on
        exit, including when the task is cancelled.
        """
        import websockets

        try:
            while True:
                message = await self.backend_ws.recv()
                await self.websocket.send_text(message)
                ## LOGGING
                self.store_message(message)
        except websockets.exceptions.ConnectionClosed:  # type: ignore
            pass
        except Exception:
            # Keep the best-effort behaviour but leave a trace for debugging.
            logging.getLogger(__name__).debug(
                "realtime backend->client forwarding stopped on error",
                exc_info=True,
            )
        finally:
            await self.log_messages()

    async def client_ack_messages(self):
        """Forward client messages to the backend until the connection closes.

        A ``websockets`` ``ConnectionClosed`` ends the loop quietly; any other
        exception propagates to the caller.
        """
        import websockets

        try:
            while True:
                message = await self.websocket.receive_text()
                ## LOGGING
                self.store_input(message=message)
                ## FORWARD TO BACKEND
                await self.backend_ws.send(message)
        except websockets.exceptions.ConnectionClosed:  # type: ignore
            pass

    async def bidirectional_forward(self):
        """Run both forwarding directions until the client side finishes.

        The backend->client loop runs as a background task; once the
        client->backend loop returns or raises, that task is cancelled and
        awaited so its cleanup (logging) completes.
        """
        import websockets

        forward_task = asyncio.create_task(self.backend_to_client_send_messages())
        try:
            await self.client_ack_messages()
        except websockets.exceptions.ConnectionClosed:  # type: ignore
            forward_task.cancel()
        finally:
            if not forward_task.done():
                forward_task.cancel()
                try:
                    await forward_task
                except asyncio.CancelledError:
                    pass