# Page-scrape residue ("Spaces: Running"); not part of webapp.py.
# webapp.py | |
import asyncio | |
import base64 | |
import json | |
import os | |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
from fastapi.responses import HTMLResponse | |
from fastapi.staticfiles import StaticFiles | |
import uvicorn | |
from handler import AudioLoop # Import your AudioLoop from above | |
# Directory that contains this module. It is both the static-file root
# mounted below and the place other handlers in this file look for files.
current_dir = os.path.dirname(os.path.realpath(__file__))

# The ASGI application object (the __main__ guard refers to it as "webapp:app").
app = FastAPI()

# Serve everything under current_dir at the "/web_ui" URL prefix.
# NOTE(review): this exposes the module's own directory, not a dedicated
# "web_ui" sub-directory -- confirm that is intended.
app.mount(
    "/web_ui",
    StaticFiles(directory=current_dir),
    name="web_ui",
)
async def get_index():
    """Return ``index.html`` from ``current_dir`` as an HTML response.

    The file is read as UTF-8 text on every call.

    Returns:
        HTMLResponse: response whose body is the file's full contents.

    Raises:
        OSError: if ``index.html`` cannot be opened (propagated from ``open``).
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file -- confirm how it is registered with ``app``.
    page_file = os.path.join(current_dir, "index.html")
    with open(page_file, mode="r", encoding="utf-8") as page:
        return HTMLResponse(content=page.read())
async def websocket_endpoint(websocket: WebSocket):
    """Bridge one client WebSocket to its own ``AudioLoop`` instance.

    After accepting the connection, ``AudioLoop.run()`` is started as a
    background task and two pumps run concurrently:

    * client -> Gemini: JSON text frames with ``"type": "audio"`` or
      ``"type": "text"`` are translated and put on ``audio_loop.out_queue``.
    * Gemini -> client: PCM chunks taken from ``audio_loop.audio_in_queue``
      are base64-encoded and sent back as ``{"type": "audio", ...}`` frames.

    The handler returns as soon as EITHER pump finishes (disconnect or
    error). The other pump and the AudioLoop task are then cancelled, so a
    disconnected client can no longer leave the handler blocked forever on
    an idle queue.

    Args:
        websocket: The incoming connection; it is accepted here.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file -- confirm how it is registered with ``app``.
    await websocket.accept()
    print("[websocket_endpoint] Client connected.")

    # Every client gets a private AudioLoop and private reordering state.
    audio_loop = AudioLoop()
    audio_ordering_buffer = {}  # seq -> message waiting for its turn
    expected_audio_seq = 0      # next sequence number to forward

    loop_task = asyncio.create_task(audio_loop.run())
    print("[websocket_endpoint] Started new AudioLoop for client")

    async def from_client_to_gemini():
        """Handles incoming messages from the client and forwards them to Gemini."""
        # Only the counter is rebound; the buffer dict is mutated in place.
        nonlocal expected_audio_seq
        try:
            while True:
                data = await websocket.receive_text()
                msg = json.loads(data)
                msg_type = msg.get("type")

                if msg_type == "audio":
                    # Decode + re-encode validates/normalises the payload.
                    raw_pcm = base64.b64decode(msg["payload"])
                    forward_msg = {
                        "realtime_input": {
                            "media_chunks": [
                                {
                                    "data": base64.b64encode(raw_pcm).decode(),
                                    "mime_type": "audio/pcm",
                                }
                            ]
                        }
                    }

                    seq = msg.get("seq")
                    if seq is None:
                        # Unsequenced chunks bypass the reordering buffer.
                        await audio_loop.out_queue.put(forward_msg)
                        continue

                    if isinstance(seq, (int, float)) and seq < expected_audio_seq:
                        # Late or duplicate chunk: its slot has already been
                        # forwarded, so buffering it would only leak memory.
                        continue

                    audio_ordering_buffer[seq] = forward_msg
                    # Flush every chunk that is now contiguous with the
                    # last one forwarded.
                    # NOTE(review): a sequence number that never arrives
                    # stalls forwarding and lets the buffer grow -- confirm
                    # the client guarantees gap-free numbering from 0.
                    while expected_audio_seq in audio_ordering_buffer:
                        msg_to_forward = audio_ordering_buffer.pop(expected_audio_seq)
                        await audio_loop.out_queue.put(msg_to_forward)
                        expected_audio_seq += 1

                elif msg_type == "text":
                    user_text = msg.get("content", "")
                    print("[from_client_to_gemini] Forwarding user text to Gemini:", user_text)
                    forward_msg = {
                        "client_content": {
                            "turn_complete": True,
                            "turns": [
                                {
                                    "role": "user",
                                    "parts": [
                                        {"text": user_text}
                                    ],
                                }
                            ],
                        }
                    }
                    await audio_loop.out_queue.put(forward_msg)

                else:
                    print("[from_client_to_gemini] Unknown message type:", msg_type)
        except WebSocketDisconnect:
            print("[from_client_to_gemini] Client disconnected.")
            loop_task.cancel()
        except Exception as e:
            print("[from_client_to_gemini] Error:", e)

    async def from_gemini_to_client():
        """Reads PCM audio from Gemini and sends it back to the client."""
        try:
            while True:
                pcm_data = await audio_loop.audio_in_queue.get()
                out_msg = {
                    "type": "audio",
                    "payload": base64.b64encode(pcm_data).decode(),
                }
                print("[from_gemini_to_client] Sending audio chunk to client. Size:", len(pcm_data))
                await websocket.send_text(json.dumps(out_msg))
        except WebSocketDisconnect:
            print("[from_gemini_to_client] Client disconnected.")
            audio_loop.stop()
        except Exception as e:
            print("[from_gemini_to_client] Error:", e)

    # Run both pumps as tasks and stop at the FIRST one to finish. Both pumps
    # swallow their own exceptions, so waiting for both (asyncio.gather)
    # would block forever once the client is gone and the audio queue idles.
    pumps = (
        asyncio.create_task(from_client_to_gemini()),
        asyncio.create_task(from_gemini_to_client()),
    )
    try:
        await asyncio.wait(pumps, return_when=asyncio.FIRST_COMPLETED)
    finally:
        print("[websocket_endpoint] WebSocket handler finished.")
        # Cancelling an already-finished task is a no-op.
        for pump in pumps:
            pump.cancel()
        await asyncio.gather(*pumps, return_exceptions=True)

        # Clean up the AudioLoop when the client disconnects
        loop_task.cancel()
        try:
            await loop_task
        except asyncio.CancelledError:
            pass
        print("[websocket_endpoint] Cleaned up AudioLoop for client")
if __name__ == "__main__":
    # Script entry point: serve the "app" object of module "webapp" on all
    # interfaces, port 7860, with uvicorn's auto-reload enabled.
    server_options = {"host": "0.0.0.0", "port": 7860, "reload": True}
    uvicorn.run("webapp:app", **server_options)