# gemini-live / webapp.py
# Nirav Madhani
# Disconnect logic and UI fix
# 9b731f8
# webapp.py
import asyncio
import base64
import json
import os
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
import uvicorn
from handler import AudioLoop # Import your AudioLoop from above
# Directory that contains this module; index.html is read from here.
current_dir = os.path.dirname(os.path.realpath(__file__))

app = FastAPI()

# Expose the module's own directory as static files under /web_ui.
# NOTE(review): this serves every file in that directory (including this
# source file) -- confirm that is intended.
app.mount(
    "/web_ui",
    StaticFiles(directory=current_dir),
    name="web_ui",
)
@app.get("/")
async def get_index():
    """Return ``index.html`` (located next to this module) as an HTML response.

    The file is read from disk on every request.
    """
    html_file = os.path.join(current_dir, "index.html")
    with open(html_file, mode="r", encoding="utf-8") as handle:
        page = handle.read()
    return HTMLResponse(page)
# Maximum number of out-of-order audio chunks buffered while waiting for a
# missing sequence number. Once exceeded, the gap is treated as lost and
# forwarding resumes from the oldest buffered chunk.
_MAX_PENDING_AUDIO_CHUNKS = 256


def _build_audio_message(payload):
    """Wrap a base64-encoded PCM payload in a ``realtime_input`` message.

    The payload is decoded and re-encoded, so a payload that cannot be
    decoded raises (``ValueError``/``TypeError``) instead of being forwarded.
    """
    raw_pcm = base64.b64decode(payload)
    return {
        "realtime_input": {
            "media_chunks": [
                {
                    "data": base64.b64encode(raw_pcm).decode(),
                    "mime_type": "audio/pcm",
                }
            ]
        }
    }


def _build_text_message(user_text):
    """Wrap user text in a single-turn, turn-complete ``client_content`` message."""
    return {
        "client_content": {
            "turn_complete": True,
            "turns": [
                {
                    "role": "user",
                    "parts": [{"text": user_text}],
                }
            ],
        }
    }


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Bridge one client WebSocket to its own ``AudioLoop`` instance.

    Client frames are JSON text of the form
    ``{"type": "audio", "payload": <base64>, "seq": <int, optional>}`` or
    ``{"type": "text", "content": <str>}``. They are translated and put on
    ``audio_loop.out_queue``. Items taken from ``audio_loop.audio_in_queue``
    are base64-encoded and sent back as ``{"type": "audio", "payload": ...}``.

    The handler returns as soon as either direction stops; the other
    direction and the ``AudioLoop`` task are then cancelled and awaited.
    """
    await websocket.accept()
    print("[websocket_endpoint] Client connected.")

    # Per-connection state: one AudioLoop plus the audio re-ordering buffer.
    audio_loop = AudioLoop()
    audio_ordering_buffer = {}
    expected_audio_seq = 0

    loop_task = asyncio.create_task(audio_loop.run())
    print("[websocket_endpoint] Started new AudioLoop for client")

    def translate_client_message(data):
        """Turn one raw client frame into the messages to enqueue right now.

        Returns a (possibly empty) list of messages in forwarding order.
        Raises ``KeyError``, ``TypeError`` or ``ValueError`` when the frame
        is malformed.
        """
        nonlocal expected_audio_seq

        msg = json.loads(data)
        if not isinstance(msg, dict):
            raise ValueError("expected a JSON object")
        msg_type = msg.get("type")

        if msg_type == "text":
            user_text = msg.get("content", "")
            print("[from_client_to_gemini] Forwarding user text to Gemini:", user_text)
            return [_build_text_message(user_text)]

        if msg_type != "audio":
            print("[from_client_to_gemini] Unknown message type:", msg_type)
            return []

        forward_msg = _build_audio_message(msg["payload"])
        seq = msg.get("seq")
        if seq is None:
            # No sequence number: forward immediately, bypassing ordering.
            return [forward_msg]
        if isinstance(seq, bool) or not isinstance(seq, int):
            raise ValueError(f"audio seq must be an integer, got {seq!r}")
        if seq < expected_audio_seq:
            # Already forwarded or skipped; storing it would only leak memory
            # because it could never be reached again.
            print("[from_client_to_gemini] Dropping stale audio chunk, seq:", seq)
            return []

        audio_ordering_buffer[seq] = forward_msg
        if (
            expected_audio_seq not in audio_ordering_buffer
            and len(audio_ordering_buffer) > _MAX_PENDING_AUDIO_CHUNKS
        ):
            # The awaited chunk is taken to be lost; resume from the oldest
            # buffered chunk instead of stalling and growing without bound.
            resume_at = min(audio_ordering_buffer)
            print(
                "[from_client_to_gemini] Audio gap too large, skipping from seq",
                expected_audio_seq,
                "to",
                resume_at,
            )
            expected_audio_seq = resume_at

        ready = []
        while expected_audio_seq in audio_ordering_buffer:
            ready.append(audio_ordering_buffer.pop(expected_audio_seq))
            expected_audio_seq += 1
        return ready

    async def from_client_to_gemini():
        """Handles incoming messages from the client and forwards them to Gemini."""
        try:
            while True:
                data = await websocket.receive_text()
                try:
                    outgoing = translate_client_message(data)
                except (KeyError, TypeError, ValueError) as e:
                    # One bad frame must not end the whole session.
                    print("[from_client_to_gemini] Ignoring malformed message:", e)
                    continue
                for forward_msg in outgoing:
                    await audio_loop.out_queue.put(forward_msg)
        except WebSocketDisconnect:
            # AudioLoop cleanup happens in the handler's ``finally`` below.
            print("[from_client_to_gemini] Client disconnected.")
        except Exception as e:
            print("[from_client_to_gemini] Error:", e)

    async def from_gemini_to_client():
        """Reads PCM audio from Gemini and sends it back to the client."""
        try:
            while True:
                pcm_data = await audio_loop.audio_in_queue.get()
                b64_pcm = base64.b64encode(pcm_data).decode()
                out_msg = {
                    "type": "audio",
                    "payload": b64_pcm,
                }
                print("[from_gemini_to_client] Sending audio chunk to client. Size:", len(pcm_data))
                await websocket.send_text(json.dumps(out_msg))
        except WebSocketDisconnect:
            print("[from_gemini_to_client] Client disconnected.")
            # NOTE(review): assumes AudioLoop exposes stop(); confirm in handler.
            audio_loop.stop()
        except Exception as e:
            print("[from_gemini_to_client] Error:", e)

    pump_tasks = [
        asyncio.create_task(from_client_to_gemini()),
        asyncio.create_task(from_gemini_to_client()),
    ]
    try:
        # Wait for the FIRST pump to finish. Waiting for both would hang
        # forever after a disconnect, because the Gemini->client pump stays
        # blocked on its queue and cleanup would never run.
        await asyncio.wait(pump_tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        print("[websocket_endpoint] WebSocket handler finished.")
        # Stop the surviving pump and the AudioLoop, then wait for all of
        # them so nothing outlives this connection.
        background_tasks = [*pump_tasks, loop_task]
        for task in background_tasks:
            task.cancel()
        results = await asyncio.gather(*background_tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, BaseException) and not isinstance(
                result, asyncio.CancelledError
            ):
                print("[websocket_endpoint] Background task failed:", result)
        print("[websocket_endpoint] Cleaned up AudioLoop for client")
if __name__ == "__main__":
    # Run the app directly with uvicorn: all interfaces, port 7860,
    # with auto-reload enabled.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "reload": True,
    }
    uvicorn.run("webapp:app", **server_options)