|
from operator import itemgetter |
|
import os |
|
from urllib import parse |
|
from pprint import pformat |
|
import socketio |
|
import time |
|
import logging |
|
from starlette.applications import Starlette |
|
from starlette.routing import Mount, Route |
|
from starlette.staticfiles import StaticFiles |
|
from dotenv import load_dotenv |
|
|
|
# Populate os.environ from a local .env file before the project modules below
# are imported -- presumably because some of them read configuration from the
# environment at import time (confirm before moving this call).
load_dotenv()
|
|
|
from src.auth import google_auth_check |
|
from src.client import Client |
|
from src.context import ContextManager |
|
from src.transcriber import Transcriber |
|
|
|
from src.simuleval_agent_directory import NoAvailableAgentException |
|
from src.simuleval_agent_directory import SimulevalAgentDirectory |
|
from src.simuleval_transcoder import SimulevalTranscoder |
|
from src.transcoder_helpers import get_transcoder_output_events |
|
from src.logging import ( |
|
initialize_logger, |
|
catch_and_log_exceptions_for_sio_event_handlers, |
|
) |
|
|
|
# Module logger; only WARNING and above are emitted at this level, which is why
# the startup banner below goes through print() rather than logger.info().
logger = initialize_logger(__name__, level=logging.WARNING)

print("{0} ⭐️ Starting Server... ⭐️ {0}".format("=" * 20))

# Socket.IO server running on the ASGI stack, accepting any CORS origin and
# reporting through the module logger.
sio = socketio.AsyncServer(
    logger=logger,
    async_mode="asgi",
    cors_allowed_origins="*",
)
socketio_app = socketio.ASGIApp(sio)

# All Socket.IO traffic is served under the /ws mount point.
app_routes = [Mount("/ws", app=socketio_app)]

# NOTE(review): debug=True is hard-coded; in Starlette this makes unhandled
# errors render tracebacks to the caller -- confirm this is intended outside
# local development.
app = Starlette(routes=app_routes, debug=True)
|
|
|
|
|
|
|
# Optional override of which models/agents to load, read from the environment
# (which load_dotenv() may have populated from a .env file). None when unset.
models_override = os.getenv("MODELS_OVERRIDE")

available_agents = SimulevalAgentDirectory()
logger.info("Building and adding agents...")
if models_override is not None:
    logger.info(f"MODELS_OVERRIDE supplied from env vars: {models_override}")
# Called whether or not an override was supplied; presumably None makes the
# directory fall back to its default agent set -- confirm in
# SimulevalAgentDirectory.build_and_add_agents.
available_agents.build_and_add_agents(models_override)

# JSON-friendly summary of the loaded agents' capabilities, computed once at
# import time.
agents_capabilities_for_json = available_agents.get_agents_capabilities_list_for_json()

# Per-connection state: Socket.IO session id (sid) -> Client.
clients = {}
|
|
|
|
|
@sio.on("connect")
@catch_and_log_exceptions_for_sio_event_handlers(logger, sio)
async def connect(sid, environ):
    """Register a newly connected Socket.IO client.

    The client must identify itself with a ``clientID`` query-string
    parameter. Connections without one are disconnected immediately, and no
    ``Client`` is registered for them.

    Args:
        sid: Socket.IO session id assigned to this connection.
        environ: Handshake request environ; only ``QUERY_STRING`` is read.
    """
    logger.info(f"📥 [event: connected] sid={sid}")

    # .get: a handshake without a query string is treated as "no parameters"
    # instead of raising KeyError.
    query_params = dict(parse.parse_qsl(environ.get("QUERY_STRING", "")))
    client_id = query_params.get("clientID")

    # NOTE(review): a `token` query parameter may be sent, but it is not
    # validated here (google_auth_check is imported but not called).
    # Never write the token to the logs, even at debug level.
    loggable_params = {
        key: ("<redacted>" if key == "token" else value)
        for key, value in query_params.items()
    }
    logger.debug(f"query_params:\n{pformat(loggable_params)}")

    if client_id is None:
        logger.info("No clientID provided. Disconnecting...")
        await sio.disconnect(sid)
        return

    clients[sid] = Client(client_id)
|
|
|
|
|
@sio.on("*")
async def catch_all(event, sid, data=None):
    """Log any event that has no dedicated handler.

    ``data`` defaults to None so that an event emitted without a payload is
    still logged, instead of failing with a TypeError for the missing
    positional argument.
    """
    logger.info(f"[unhandled event: {event}] sid={sid} data={data}")
|
|
|
|
|
@sio.event
@catch_and_log_exceptions_for_sio_event_handlers(logger, sio)
async def configure_stream(sid, config):
    """Configure how audio from client ``sid`` will be processed.

    ``config["manual_transcribe"]`` selects one of two modes:

    * truthy: a ``Transcriber`` is created and started for the client
      (any previously configured transcriber is stopped first);
    * falsy: a ``SimulevalTranscoder`` is built around the agent named by
      ``config["model_name"]``, using ``config["rate"]`` and
      ``config["buffer_limit"]``. ``config["model_type"]`` is stored as the
      client's requested output type. The transcoder is started immediately
      when ``config["async_processing"]`` is truthy. Any previous transcoder
      is flagged to close before being replaced.

    In both modes a fresh ``ContextManager`` is attached to the client.

    Returns:
        ``{"status": "ok", "message": "server_ready"}`` on success. Otherwise
        the client is disconnected and ``{"status": "error", "message": ...}``
        is returned: when the client is unknown, when no agent is available
        for ``model_name``, or when the transcoder fails to initialize.
    """
    # .get so the None guard below is reachable; indexing would raise KeyError.
    client_obj = clients.get(sid)
    logger.debug(f"[configure_stream] sid={sid}")

    if client_obj is None:
        logger.error(f"No client object for {sid}")
        await sio.disconnect(sid)
        return {"status": "error", "message": "member_or_room_is_none"}

    debug = config.get("debug")
    async_processing = config.get("async_processing")
    manual_transcribe = config.get("manual_transcribe")
    client_obj.manual_transcribe = manual_transcribe

    if manual_transcribe:
        # Stop a transcriber left over from an earlier configure_stream call
        # before replacing it, mirroring how an old transcoder is closed below.
        if client_obj.transcriber:
            client_obj.transcriber.stop()
        client_obj.transcriber = Transcriber()
        client_obj.transcriber.start()
    else:
        model_type = config.get("model_type")
        client_obj.requested_output_type = model_type

        model_name = config.get("model_name")
        try:
            agent = available_agents.get_agent_or_throw(model_name)
        except NoAvailableAgentException as e:
            logger.warning(f"Error while getting agent: {e}")
            await sio.disconnect(sid)
            return {"status": "error", "message": str(e)}

        if client_obj.transcoder:
            logger.warning(
                "Member already has a transcoder configured. Closing it, and overwriting with a new transcoder..."
            )
            client_obj.transcoder.close = True

        t0 = time.perf_counter()
        try:
            # A missing/invalid "rate" or "buffer_limit" also lands in the
            # except branch below.
            client_obj.transcoder = SimulevalTranscoder(
                agent,
                config["rate"],
                debug=debug,
                buffer_limit=int(config["buffer_limit"]),
            )
        except Exception as e:
            logger.warning(f"Got exception while initializing agents: {e}")
            await sio.disconnect(sid)
            return {"status": "error", "message": str(e)}
        t1 = time.perf_counter()
        logger.debug(f"Booting up VAD and transcoder took {t1-t0} sec")

        if async_processing:
            client_obj.transcoder.start()

    client_obj.context = ContextManager()
    return {"status": "ok", "message": "server_ready"}
|
|
|
|
|
@sio.on("set_dynamic_config")
@catch_and_log_exceptions_for_sio_event_handlers(logger, sio)
async def set_dynamic_config(
    sid,
    partial_config,
):
    """Merge ``partial_config`` into the client's dynamic transcoder config.

    Keys in ``partial_config`` override existing keys. The merged dict is
    passed to the transcoder on every ``incoming_audio`` call. If the update
    contains ``targetLanguage`` and the client has a context manager, the
    context manager's language is updated as well.

    Returns:
        ``{"status": "ok", "message": "dynamic_config_set"}`` on success, or
        an error dict after disconnecting when the client is unknown.
    """
    # .get so the None guard below is reachable; indexing would raise KeyError.
    client_obj = clients.get(sid)

    if client_obj is None:
        logger.error(f"No client object for {sid}")
        await sio.disconnect(sid)
        return {"status": "error", "message": "member_or_room_is_none"}

    new_dynamic_config = {
        **(client_obj.transcoder_dynamic_config or {}),
        **partial_config,
    }
    logger.info(
        f"[set_dynamic_config] Setting new dynamic config:\n\n{pformat(new_dynamic_config)}\n"
    )

    client_obj.transcoder_dynamic_config = new_dynamic_config

    # A partial update need not include the target language; only forward it
    # when it was actually supplied (indexing would raise KeyError otherwise).
    if client_obj.context and "targetLanguage" in partial_config:
        client_obj.context.set_language(partial_config["targetLanguage"])

    return {"status": "ok", "message": "dynamic_config_set"}
|
|
|
|
|
@sio.event
async def incoming_audio(sid, blob):
    """Route an audio chunk from client ``sid`` to its configured processor.

    In manual-transcribe mode the chunk is forwarded to the client's
    ``Transcriber`` unchanged. Otherwise it must be ``bytes``; it is fed to
    the client's transcoder together with the client's dynamic config. Any
    resulting output events are then emitted back to the client, filtered by
    the client's requested output type:

    * ``translation_speech`` is sent for ``"s2s"`` / ``"s2s&t"``;
    * ``translation_text`` is sent for ``"s2t"`` / ``"s2s&t"`` and also
      appended to the client's context.

    After the events are processed, the current context (if any) is emitted
    as a ``context`` event.

    Returns:
        An error dict after disconnecting when the client is unknown;
        otherwise None.
    """
    # .get so the None guard below is reachable; this handler has no
    # exception-logging wrapper, so a KeyError would escape to socketio.
    client_obj = clients.get(sid)

    if client_obj is None:
        logger.error(f"No client object for {sid}")
        await sio.disconnect(sid)
        return {"status": "error", "message": "member_or_room_is_none"}

    if client_obj.manual_transcribe:
        client_obj.transcriber.send_audio(blob)
        return

    if not isinstance(blob, bytes):
        logger.error(
            f"[incoming_audio] Received audio from {sid}, but it was not of type `bytes`. type(blob) = {type(blob)}"
        )
        return

    if client_obj.transcoder is None:
        logger.error(
            f"[incoming_audio] Received audio from {sid}, but no transcoder configured to process it (member.transcoder is None). This should not happen."
        )
        return

    client_obj.transcoder.process_incoming_bytes(
        blob, dynamic_config=client_obj.transcoder_dynamic_config
    )

    events = get_transcoder_output_events(client_obj.transcoder)
    logger.debug(f"[incoming_audio] transcoder output events: {len(events)}")

    if len(events) == 0:
        logger.debug("[incoming_audio] No transcoder output to send")
        return

    output_type = client_obj.requested_output_type
    for e in events:
        event_type = e["event"]
        if event_type == "translation_speech":
            if output_type in ("s2s", "s2s&t"):
                logger.debug("[incoming_audio] Sending translation_speech event")
                await sio.emit("translation_speech", e, room=sid)
            else:
                # A known event type the client simply did not ask for is
                # not an error; previously it was logged as "Unexpected".
                logger.debug(
                    f"[incoming_audio] Skipping {event_type} event (requested_output_type={output_type})"
                )
        elif event_type == "translation_text":
            if output_type in ("s2t", "s2s&t"):
                logger.debug("[incoming_audio] Sending translation_text event")
                await sio.emit("translation_text", e, room=sid)
                client_obj.context.add_text_chunk(e["payload"])
            else:
                logger.debug(
                    f"[incoming_audio] Skipping {event_type} event (requested_output_type={output_type})"
                )
        else:
            logger.error(f"[incoming_audio] Unexpected event type: {e['event']}")

    new_context = client_obj.context.get_current_context()
    if new_context:
        await sio.emit(
            "context",
            {"event": "context", "payload": new_context},
            room=sid,
        )
|
|
|
|
|
@sio.event
async def stop_stream(sid):
    """Stop processing audio for client ``sid`` without forgetting the client.

    Flags any transcoder to close and detaches it, and closes the
    transcriber's connection if one exists. The ``Client`` entry stays
    registered so the stream can be configured again.

    Returns:
        An error dict after disconnecting when the client is unknown;
        otherwise None.
    """
    # .get so the None guard below is reachable; indexing would raise KeyError.
    client_obj = clients.get(sid)

    if client_obj is None:
        logger.error(f"No client object for {sid}")
        await sio.disconnect(sid)
        return {"status": "error", "message": "member_or_room_is_none"}

    if client_obj.transcoder:
        client_obj.transcoder.close = True
        client_obj.transcoder = None

    if client_obj.transcriber:
        client_obj.transcriber.close_connection()
|
|
|
|
|
@sio.event
async def disconnect(sid):
    """Release the resources held for client ``sid`` and forget it.

    A session may never have been registered, e.g. when ``connect`` rejected
    it for lacking a ``clientID``. ``clients.pop(sid, None)`` makes this
    handler a no-op for such sessions instead of raising KeyError. Removing
    the entry first also guarantees it is dropped even if stopping the
    transcriber raises.
    """
    client_obj = clients.pop(sid, None)
    if client_obj is None:
        return

    if client_obj.transcriber:
        client_obj.transcriber.stop()

    if client_obj.transcoder:
        client_obj.transcoder.close = True
        client_obj.transcoder = None
|
|