File size: 4,692 Bytes
cc7c705
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b731f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc7c705
 
9b731f8
cc7c705
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# handler.py
import asyncio
import base64
import json
import os
import traceback
from websockets.asyncio.client import connect

# Gemini Live (BidiGenerateContent) websocket endpoint configuration.
host = "generativelanguage.googleapis.com"
model = "gemini-2.0-flash-exp"
# Read at import time: importing this module raises KeyError if the
# GOOGLE_API_KEY environment variable is not set.
api_key = os.environ["GOOGLE_API_KEY"]
# The API key is carried as a query parameter, so avoid logging this URL.
uri = (
    "wss://{}/ws/google.ai.generativelanguage.v1alpha."
    "GenerativeService.BidiGenerateContent?key={}"
).format(host, api_key)

class AudioLoop:
    """Pump messages between local asyncio queues and the Gemini websocket.

    Dicts placed on ``out_queue`` are JSON-encoded and sent to Gemini.
    Base64 ``inlineData`` payloads from the model are decoded and placed on
    ``audio_in_queue``. Tool calls from the model are acknowledged with a
    fixed ``'ok'`` result (the functions are not executed here).
    """

    def __init__(self):
        # Websocket connection; assigned by run() once connected.
        self.ws = None
        # Queue for messages to be sent *to* Gemini (JSON-serialisable objects)
        self.out_queue = asyncio.Queue()
        # Queue for PCM audio (decoded bytes) received *from* Gemini
        self.audio_in_queue = asyncio.Queue()

    async def startup(self, tools=None):
        """Send the model setup message to Gemini and print its reply.

        Args:
            tools: Optional list of tool declarations to enable for the model.
                Omitted from the setup message when falsy.
        """
        setup_msg = {"setup": {"model": f"models/{model}"}}
        if tools:
            setup_msg["setup"]["tools"] = tools

        await self.ws.send(json.dumps(setup_msg))

        # The first frame after setup is Gemini's acknowledgement.
        raw_response = await self.ws.recv()
        setup_response = json.loads(raw_response)
        print("[AudioLoop] Setup response from Gemini:", setup_response)

    async def send_realtime(self):
        """Forward every message from out_queue to Gemini, forever.

        Runs until cancelled or until ``ws.send`` raises.
        """
        while True:
            msg = await self.out_queue.get()
            await self.ws.send(json.dumps(msg))

    async def receive_audio(self):
        """Read and dispatch Gemini messages until the websocket stops yielding.

        For each message:
        - every part of ``serverContent.modelTurn.parts`` that carries
          ``inlineData.data`` is base64-decoded and put on audio_in_queue
          (a turn may mix text and audio parts, and ``parts`` may be empty);
        - a ``toolCall`` is answered via handle_tool_call();
        - ``serverContent.turnComplete`` is logged.
        """
        async for raw_response in self.ws:
            response = json.loads(raw_response)

            # `or {}` also tolerates explicit JSON nulls.
            server_content = response.get("serverContent") or {}
            model_turn = server_content.get("modelTurn") or {}
            for part in model_turn.get("parts") or []:
                inline_data = part.get("inlineData") or {}
                b64data = inline_data.get("data")
                if b64data is not None:
                    await self.audio_in_queue.put(base64.b64decode(b64data))

            tool_call = response.pop("toolCall", None)
            if tool_call is not None:
                await self.handle_tool_call(tool_call)

            if server_content.get("turnComplete"):
                print("[AudioLoop] Gemini turn complete")

    async def handle_tool_call(self, tool_call):
        """Reply 'ok' to each function call listed in ``tool_call``.

        Sends one ``tool_response`` message per entry of
        ``tool_call['functionCalls']`` (nothing if the key is missing).
        """
        print("    ", tool_call)
        for fc in tool_call.get("functionCalls") or []:
            msg = {
                "tool_response": {
                    "function_responses": [{
                        "id": fc["id"],
                        "name": fc["name"],
                        "response": {"result": {"string_value": "ok"}},
                    }]
                }
            }
            print(">>> ", msg)
            await self.ws.send(json.dumps(msg))

    async def run(self):
        """Connect to Gemini, send setup, then pump both directions.

        Returns once the receive loop ends (the websocket stopped yielding
        messages). Other exceptions are printed with a traceback and
        re-raised. Cancellation is logged and swallowed.
        """
        try:
            turn_on_the_lights_schema = {'name': 'turn_on_the_lights'}
            turn_off_the_lights_schema = {'name': 'turn_off_the_lights'}
            tools = [
                {'google_search': {}},
                {'function_declarations': [turn_on_the_lights_schema, turn_off_the_lights_schema]},
                {'code_execution': {}},
            ]
            async with connect(uri, additional_headers={"Content-Type": "application/json"}) as ws:
                self.ws = ws
                await self.startup(tools)

                try:
                    async with asyncio.TaskGroup() as tg:
                        send_task = tg.create_task(self.send_realtime())
                        # Wait on the receiver rather than a never-resolving
                        # Future. When the server closes the socket the receive
                        # loop ends. The sender would otherwise stay blocked
                        # on out_queue, and the group would wait forever.
                        await tg.create_task(self.receive_audio())
                        send_task.cancel()
                finally:
                    # The TaskGroup has already awaited both tasks here, so
                    # only the socket needs closing.
                    try:
                        await self.ws.close()
                        print("[AudioLoop] Closed WebSocket connection")
                    except Exception as e:
                        print(f"[AudioLoop] Error closing Gemini connection: {e}")
                    print("[AudioLoop] Cleanup complete")

        except asyncio.CancelledError:
            # NOTE(review): swallowing cancellation makes a cancelled run()
            # look like a normal return. The asyncio docs discourage this,
            # since it can confuse asyncio.timeout()/TaskGroup in callers.
            # Kept for backward compatibility; consider re-raising.
            print("[AudioLoop] Cancelled")
        except Exception:
            traceback.print_exc()
            raise