File size: 5,622 Bytes
cc7c705
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b731f8
 
cc7c705
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b731f8
cc7c705
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265e66c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# webapp.py

import asyncio
import base64
import json
import os

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
import uvicorn

from handler import AudioLoop  # Import your AudioLoop from above

app = FastAPI()

# Directory that holds this script (symlinks resolved); index.html is read from here.
current_dir = os.path.dirname(
    os.path.realpath(__file__)
)

# Serve static files under the /web_ui URL prefix. StaticFiles is rooted at
# current_dir itself, not at a dedicated "web_ui" sub-folder.
# NOTE(review): because of that, every file next to this script -- including
# webapp.py itself and presumably handler.py / any config kept alongside -- is
# downloadable via /web_ui/<filename>. Consider moving the front-end assets
# into their own folder and pointing StaticFiles at it.
app.mount(
    "/web_ui",
    StaticFiles(directory=current_dir),
    name="web_ui",
)

@app.get("/")
async def get_index():
    """Serve the front-end page.

    Reads ``index.html`` from the script directory on every request and
    returns it as an HTML response.
    """
    page_file = os.path.join(current_dir, "index.html")
    with open(page_file, encoding="utf-8") as handle:
        page = handle.read()
    return HTMLResponse(content=page)

# Upper bound on out-of-order audio chunks held while waiting for a missing
# sequence number. Without it, a client that skips a number (or sends a bogus
# large one) makes the buffer grow for the rest of the session.
_MAX_PENDING_AUDIO_CHUNKS = 64


def _audio_forward_message(payload_b64):
    """Wrap one base64-encoded PCM chunk from the browser for AudioLoop.out_queue.

    The payload is decoded and re-encoded. This normalises it: characters
    outside the base64 alphabet are discarded. Bad padding or a non-ASCII
    string raises ``ValueError`` (``binascii.Error`` is a subclass), and a
    non-string/bytes payload raises ``TypeError``.
    """
    raw_pcm = base64.b64decode(payload_b64)
    return {
        "realtime_input": {
            "media_chunks": [
                {
                    "data": base64.b64encode(raw_pcm).decode(),
                    "mime_type": "audio/pcm",
                }
            ]
        }
    }


def _text_forward_message(user_text):
    """Wrap a user text message as a single, completed user turn for AudioLoop.out_queue."""
    return {
        "client_content": {
            "turn_complete": True,
            "turns": [
                {
                    "role": "user",
                    "parts": [
                        {"text": user_text}
                    ],
                }
            ],
        }
    }


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Bridge one browser WebSocket connection to a dedicated AudioLoop.

    Client -> server messages are JSON objects with a ``type`` field:

    * ``{"type": "audio", "payload": <base64 PCM>, "seq": <int, optional>}``
      is forwarded as a ``realtime_input`` message. Chunks that carry ``seq``
      are re-ordered so they reach the AudioLoop in sequence order, starting
      at 0.
    * ``{"type": "text", "content": <str>}`` is forwarded as a completed user
      turn.

    Malformed messages are logged and skipped.

    Server -> client messages are ``{"type": "audio", "payload": <base64 PCM>}``,
    built from the bytes the AudioLoop puts on ``audio_in_queue``.

    The session ends as soon as either direction stops (client disconnect or
    error). The other direction and the AudioLoop task are then cancelled.
    """
    await websocket.accept()
    print("[websocket_endpoint] Client connected.")

    # Create a new AudioLoop instance for this client
    audio_loop = AudioLoop()
    # seq -> forward message, for audio chunks that arrived ahead of expected_audio_seq.
    audio_ordering_buffer = {}
    expected_audio_seq = 0

    # Start the AudioLoop for this client
    loop_task = asyncio.create_task(audio_loop.run())
    print("[websocket_endpoint] Started new AudioLoop for client")

    async def enqueue_sequenced_audio(seq, forward_msg):
        """Buffer one numbered audio chunk, then flush every chunk that is now in order."""
        nonlocal expected_audio_seq
        if not isinstance(seq, int) or seq < expected_audio_seq:
            # Duplicates or already-skipped chunks would otherwise sit in the
            # buffer forever, and a non-int seq can never match the counter.
            print("[from_client_to_gemini] Dropping audio chunk with stale or invalid seq:", seq)
            return
        audio_ordering_buffer[seq] = forward_msg
        if len(audio_ordering_buffer) > _MAX_PENDING_AUDIO_CHUNKS:
            # Too many chunks are waiting on a gap: treat the missing chunk as
            # lost and resume from the oldest chunk we do have.
            expected_audio_seq = min(audio_ordering_buffer)
        while expected_audio_seq in audio_ordering_buffer:
            await audio_loop.out_queue.put(audio_ordering_buffer.pop(expected_audio_seq))
            expected_audio_seq += 1

    async def from_client_to_gemini():
        """Handles incoming messages from the client and forwards them to Gemini."""
        try:
            while True:
                data = await websocket.receive_text()
                try:
                    msg = json.loads(data)
                    msg_type = msg.get("type")
                except (ValueError, AttributeError) as e:
                    # Not JSON, or JSON that is not an object: skip this one
                    # message instead of tearing down the whole session.
                    print("[from_client_to_gemini] Ignoring malformed message:", e)
                    continue

                # Handle audio data from client
                if msg_type == "audio":
                    try:
                        forward_msg = _audio_forward_message(msg["payload"])
                    except (KeyError, TypeError, ValueError) as e:
                        print("[from_client_to_gemini] Ignoring bad audio payload:", e)
                        continue
                    seq = msg.get("seq")
                    if seq is None:
                        # If no sequence number is provided, forward immediately
                        await audio_loop.out_queue.put(forward_msg)
                    else:
                        await enqueue_sequenced_audio(seq, forward_msg)

                # Handle text data from client
                elif msg_type == "text":
                    user_text = msg.get("content", "")
                    print("[from_client_to_gemini] Forwarding user text to Gemini:", user_text)
                    await audio_loop.out_queue.put(_text_forward_message(user_text))

                else:
                    print("[from_client_to_gemini] Unknown message type:", msg_type)

        except WebSocketDisconnect:
            print("[from_client_to_gemini] Client disconnected.")
        except Exception as e:
            print("[from_client_to_gemini] Error:", e)

    async def from_gemini_to_client():
        """Reads PCM audio from Gemini and sends it back to the client."""
        try:
            while True:
                pcm_data = await audio_loop.audio_in_queue.get()
                b64_pcm = base64.b64encode(pcm_data).decode()

                out_msg = {
                    "type": "audio",
                    "payload": b64_pcm
                }
                print("[from_gemini_to_client] Sending audio chunk to client. Size:", len(pcm_data))
                await websocket.send_text(json.dumps(out_msg))

        except WebSocketDisconnect:
            print("[from_gemini_to_client] Client disconnected.")
            audio_loop.stop()
        except Exception as e:
            print("[from_gemini_to_client] Error:", e)

    # Run both directions as tasks and stop at the FIRST one that finishes.
    # asyncio.gather would wait for *both*. After the client leaves, the
    # Gemini->client pump keeps waiting on audio_in_queue.get(), so the handler
    # never reached its cleanup and every connection leaked its tasks.
    # NOTE(review): this assumes AudioLoop() creates out_queue/audio_in_queue
    # in __init__. If they only exist once run() has started, the
    # Gemini->client pump can fail at startup and end the session -- confirm
    # in handler.py.
    pumps = [
        asyncio.create_task(from_client_to_gemini()),
        asyncio.create_task(from_gemini_to_client()),
    ]
    try:
        await asyncio.wait(pumps, return_when=asyncio.FIRST_COMPLETED)
    finally:
        print("[websocket_endpoint] WebSocket handler finished.")
        # Clean up the AudioLoop and whichever pump is still running.
        for task in pumps:
            task.cancel()
        loop_task.cancel()
        results = await asyncio.gather(*pumps, loop_task, return_exceptions=True)
        for result in results:
            # CancelledError is a BaseException, so only genuine failures match here.
            if isinstance(result, Exception):
                print("[websocket_endpoint] Task ended with error:", repr(result))
        print("[websocket_endpoint] Cleaned up AudioLoop for client")

if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # port 7860, with auto-reload enabled (reload needs the "module:attr"
    # import string rather than the app object).
    uvicorn.run(
        "webapp:app",
        host="0.0.0.0",
        port=7860,
        reload=True,
    )