Spaces:
Running
Running
File size: 5,622 Bytes
cc7c705 9b731f8 cc7c705 9b731f8 cc7c705 265e66c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
# webapp.py
import asyncio
import base64
import json
import os
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
import uvicorn
from handler import AudioLoop # Import your AudioLoop from above
app = FastAPI()

# Static assets are served straight out of the directory that holds this
# module, reachable under the /web_ui URL prefix.
# NOTE(review): that directory contains this very file (and anything else
# stored beside it), so all of it becomes downloadable over HTTP; consider
# pointing StaticFiles at a dedicated assets subdirectory instead.
current_dir = os.path.dirname(os.path.realpath(__file__))
app.mount(
    "/web_ui",
    StaticFiles(directory=current_dir),
    name="web_ui",
)
@app.get("/")
async def get_index():
    """Return ``index.html`` (located beside this module) as an HTML page."""
    page_file = os.path.join(current_dir, "index.html")
    with open(page_file, "r", encoding="utf-8") as handle:
        page = handle.read()
    return HTMLResponse(content=page)
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Bridge one browser WebSocket client to its own ``AudioLoop``.

    Client -> Gemini: every text frame is parsed as a JSON object.

    * ``{"type": "audio", "payload": <base64>, "seq": <int, optional>}`` is
      wrapped in a ``realtime_input`` message (mime type ``audio/pcm``) and
      put on ``audio_loop.out_queue``. When ``seq`` is given, chunks are
      released in ``seq`` order starting from 0.
    * ``{"type": "text", "content": <str>}`` is put on the same queue as a
      completed user turn.

    Malformed frames are logged and skipped instead of ending the session.

    Gemini -> client: every chunk taken from ``audio_loop.audio_in_queue`` is
    sent back as ``{"type": "audio", "payload": <base64>}``.

    The handler returns as soon as either direction stops (disconnect or
    error); the other direction and the ``AudioLoop`` task are then
    cancelled and awaited, so nothing is left running for a gone client.
    """
    await websocket.accept()
    print("[websocket_endpoint] Client connected.")

    # Each client gets its own AudioLoop, run as a background task.
    # NOTE(review): assumes AudioLoop exposes out_queue / audio_in_queue as
    # asyncio queues usable once run() has been scheduled — confirm against
    # handler.py.
    audio_loop = AudioLoop()
    loop_task = asyncio.create_task(audio_loop.run())
    print("[websocket_endpoint] Started new AudioLoop for client")

    # Sequenced audio that arrived ahead of an earlier chunk waits here,
    # keyed by seq, until everything before it has been forwarded.
    audio_ordering_buffer = {}
    expected_audio_seq = 0
    # Upper bound on waiting chunks. Past it, the chunk being waited for is
    # treated as lost so a single dropped frame cannot stall audio (and grow
    # the buffer) for the rest of the session.
    max_pending_audio_chunks = 64

    async def enqueue_audio(forward_msg, seq):
        """Forward ``forward_msg`` now, or once all earlier ``seq`` values are forwarded."""
        nonlocal expected_audio_seq
        if seq is None:
            # Unsequenced audio is forwarded immediately.
            await audio_loop.out_queue.put(forward_msg)
            return
        # bool is an int subclass; reject it along with non-integers, which
        # could never match expected_audio_seq and would sit in the buffer.
        if isinstance(seq, bool) or not isinstance(seq, int):
            print("[from_client_to_gemini] Ignoring audio with invalid seq:", seq)
            return
        if seq < expected_audio_seq:
            # Late duplicate, or a chunk already given up on: it can never be
            # released, so buffering it would only leak memory.
            return
        audio_ordering_buffer[seq] = forward_msg
        if len(audio_ordering_buffer) > max_pending_audio_chunks:
            # Skip the missing chunk(s) and resume from the oldest buffered one.
            expected_audio_seq = min(audio_ordering_buffer)
        while expected_audio_seq in audio_ordering_buffer:
            await audio_loop.out_queue.put(audio_ordering_buffer.pop(expected_audio_seq))
            expected_audio_seq += 1

    async def from_client_to_gemini():
        """Handles incoming messages from the client and forwards them to Gemini."""
        try:
            while True:
                data = await websocket.receive_text()
                try:
                    msg = json.loads(data)
                except json.JSONDecodeError as e:
                    print("[from_client_to_gemini] Ignoring invalid JSON:", e)
                    continue
                if not isinstance(msg, dict):
                    print("[from_client_to_gemini] Ignoring non-object message:", msg)
                    continue
                msg_type = msg.get("type")

                if msg_type == "audio":
                    try:
                        raw_pcm = base64.b64decode(msg["payload"])
                    except (KeyError, TypeError, ValueError) as e:
                        # binascii.Error (bad base64) subclasses ValueError.
                        print("[from_client_to_gemini] Ignoring malformed audio message:", e)
                        continue
                    forward_msg = {
                        "realtime_input": {
                            "media_chunks": [
                                {
                                    "data": base64.b64encode(raw_pcm).decode(),
                                    "mime_type": "audio/pcm",
                                }
                            ]
                        }
                    }
                    await enqueue_audio(forward_msg, msg.get("seq"))
                elif msg_type == "text":
                    user_text = msg.get("content", "")
                    print("[from_client_to_gemini] Forwarding user text to Gemini:", user_text)
                    forward_msg = {
                        "client_content": {
                            "turn_complete": True,
                            "turns": [
                                {
                                    "role": "user",
                                    "parts": [
                                        {"text": user_text}
                                    ],
                                }
                            ],
                        }
                    }
                    await audio_loop.out_queue.put(forward_msg)
                else:
                    print("[from_client_to_gemini] Unknown message type:", msg_type)
        except WebSocketDisconnect:
            print("[from_client_to_gemini] Client disconnected.")
        except Exception as e:
            print("[from_client_to_gemini] Error:", e)

    async def from_gemini_to_client():
        """Reads PCM audio from Gemini and sends it back to the client."""
        try:
            while True:
                pcm_data = await audio_loop.audio_in_queue.get()
                out_msg = {
                    "type": "audio",
                    "payload": base64.b64encode(pcm_data).decode(),
                }
                print("[from_gemini_to_client] Sending audio chunk to client. Size:", len(pcm_data))
                await websocket.send_text(json.dumps(out_msg))
        except WebSocketDisconnect:
            print("[from_gemini_to_client] Client disconnected.")
            audio_loop.stop()
        except Exception as e:
            print("[from_gemini_to_client] Error:", e)

    client_task = asyncio.create_task(from_client_to_gemini())
    gemini_task = asyncio.create_task(from_gemini_to_client())
    try:
        # Stop as soon as EITHER direction finishes. Waiting for both would
        # hang after a disconnect: from_gemini_to_client would stay blocked on
        # audio_in_queue.get() with nothing left to feed it.
        await asyncio.wait({client_task, gemini_task}, return_when=asyncio.FIRST_COMPLETED)
    finally:
        print("[websocket_endpoint] WebSocket handler finished.")
        # Clean up both directions and the AudioLoop for this client.
        for task in (client_task, gemini_task, loop_task):
            task.cancel()
        results = await asyncio.gather(client_task, gemini_task, loop_task, return_exceptions=True)
        for name, result in zip(("from_client_to_gemini", "from_gemini_to_client", "AudioLoop"), results):
            # CancelledError is a BaseException, so only real failures log here.
            if isinstance(result, Exception):
                print("[websocket_endpoint]", name, "ended with error:", result)
        print("[websocket_endpoint] Cleaned up AudioLoop for client")
if __name__ == "__main__":
    # Script entry point: serve this module's `app` on every interface,
    # port 7860, with auto-reload enabled.
    uvicorn.run(
        "webapp:app",
        host="0.0.0.0",
        port=7860,
        reload=True,
    )
|