# gemini-live / handler.py
# Nirav Madhani
# Disconnect logic and UI fix
# 9b731f8
# handler.py
import asyncio
import base64
import json
import os
import traceback
from websockets.asyncio.client import connect
# Connection settings for the Gemini Live bidirectional-streaming websocket.
host = "generativelanguage.googleapis.com"
model = "gemini-2.0-flash-exp"
# Read once at import time; a missing variable raises KeyError here.
api_key = os.environ["GOOGLE_API_KEY"]
# NOTE(review): the API key travels in the query string, so avoid logging `uri`.
uri = "".join((
    "wss://",
    host,
    "/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent",
    "?key=",
    api_key,
))
class AudioLoop:
    """Relay messages between local asyncio queues and a Gemini Live websocket.

    Attributes:
        ws: The websocket connection; assigned by ``run()`` (None before that).
        out_queue: JSON-serialisable messages to be sent *to* Gemini.
        audio_in_queue: Base64-decoded ``inlineData`` payloads (PCM audio)
            received *from* Gemini.
    """

    def __init__(self):
        self.ws = None
        # Queue for messages to be sent *to* Gemini
        self.out_queue = asyncio.Queue()
        # Queue for PCM audio received *from* Gemini
        self.audio_in_queue = asyncio.Queue()

    async def startup(self, tools=None):
        """Send the model setup message to Gemini and print its reply.

        Args:
            tools: Optional list of tools to enable for the model.
        """
        setup_msg = {"setup": {"model": f"models/{model}"}}
        if tools:
            setup_msg["setup"]["tools"] = tools
        await self.ws.send(json.dumps(setup_msg))
        # The first server message is the reply to the setup request; it is
        # only printed, not validated.
        raw_response = await self.ws.recv()
        setup_response = json.loads(raw_response)
        print("[AudioLoop] Setup response from Gemini:", setup_response)

    async def send_realtime(self):
        """Forward messages from ``out_queue`` to Gemini forever (until cancelled)."""
        while True:
            msg = await self.out_queue.get()
            await self.ws.send(json.dumps(msg))

    @staticmethod
    def _extract_pcm_chunks(response):
        """Return the decoded bytes of every ``inlineData`` part in ``response``.

        Messages without a model turn, with an empty ``parts`` list, or whose
        parts carry no inline data simply yield an empty list.
        """
        server_content = response.get("serverContent") or {}
        model_turn = server_content.get("modelTurn") or {}
        chunks = []
        for part in model_turn.get("parts") or []:
            inline = part.get("inlineData") or {}
            b64data = inline.get("data")
            if b64data is not None:
                chunks.append(base64.b64decode(b64data))
        return chunks

    async def receive_audio(self):
        """Consume Gemini messages until the websocket iterator ends.

        Every inline-data part of a model turn is decoded onto
        ``audio_in_queue``; ``toolCall`` messages are answered through
        ``handle_tool_call``. Returns normally when the server closes the
        stream.
        """
        async for raw_response in self.ws:
            response = json.loads(raw_response)
            # Scan all parts: audio is not guaranteed to be the first part,
            # and ``parts`` may be empty.
            for pcm_data in self._extract_pcm_chunks(response):
                await self.audio_in_queue.put(pcm_data)
            tool_call = response.pop('toolCall', None)
            if tool_call is not None:
                await self.handle_tool_call(tool_call)
            server_content = response.get("serverContent") or {}
            if server_content.get("turnComplete"):
                print("[AudioLoop] Gemini turn complete")

    async def handle_tool_call(self, tool_call):
        """Acknowledge every function call in ``tool_call`` with a stub result.

        No function is actually executed: each call is answered with the fixed
        result ``{'string_value': 'ok'}``, one ``tool_response`` message per call.
        """
        print(" ", tool_call)
        for fc in tool_call.get('functionCalls') or []:
            msg = {
                'tool_response': {
                    'function_responses': [{
                        'id': fc['id'],
                        'name': fc['name'],
                        'response': {'result': {'string_value': 'ok'}},
                    }]
                }
            }
            print('>>> ', msg)
            await self.ws.send(json.dumps(msg))

    async def run(self):
        """Connect to Gemini, send setup, then relay traffic until the session ends.

        Returns when Gemini closes the stream or when the task is cancelled.
        Other errors are printed with a traceback and re-raised; errors from
        the relay loops arrive wrapped in an ExceptionGroup by the TaskGroup.
        """
        try:
            turn_on_the_lights_schema = {'name': 'turn_on_the_lights'}
            turn_off_the_lights_schema = {'name': 'turn_off_the_lights'}
            tools = [
                {'google_search': {}},
                {'function_declarations': [turn_on_the_lights_schema, turn_off_the_lights_schema]},
                {'code_execution': {}},
            ]
            async with connect(uri, additional_headers={"Content-Type": "application/json"}) as ws:
                self.ws = ws
                await self.startup(tools)
                try:
                    async with asyncio.TaskGroup() as tg:
                        send_task = tg.create_task(self.send_realtime())
                        # Runs until Gemini closes the stream (or raises).
                        await self.receive_audio()
                        # The server ended the session: stop the sender so the
                        # TaskGroup can exit instead of waiting on out_queue
                        # forever. (On error/cancellation the TaskGroup itself
                        # cancels and awaits its children.)
                        send_task.cancel()
                finally:
                    try:
                        await self.ws.close()
                        print("[AudioLoop] Closed WebSocket connection")
                    except Exception as e:
                        print(f"[AudioLoop] Error closing Gemini connection: {e}")
                    print("[AudioLoop] Cleanup complete")
        except asyncio.CancelledError:
            # NOTE(review): cancellation is swallowed, so a cancelled run()
            # task completes normally; re-raise if callers need to observe it.
            print("[AudioLoop] Cancelled")
        except Exception:
            traceback.print_exc()
            raise