# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" | |
## Setup | |
To install the dependencies for this script, run: | |
``` | |
pip install google-genai opencv-python pyaudio pillow mss | |
``` | |
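
On Python versions before 3.11, also install the `taskgroup` and
`exceptiongroup` backports that the script imports below:

```
pip install taskgroup exceptiongroup
```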

Before running this script, ensure the `GOOGLE_API_KEY` environment
variable is set to the API key you obtained from Google AI Studio.

Important: **Use headphones**. This script uses the system default audio
input and output, which often won't include echo cancellation, so without
headphones the model may hear and interrupt itself.

## Run

To run the script:

```
python live_api_starter.py
```

The script takes a video-mode flag, `--mode`, which can be "camera", "screen", or "none".
The default is "camera". To share your screen, run:

```
python live_api_starter.py --mode screen
```
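
To run with audio and text only, with no video input:

```
python live_api_starter.py --mode none
```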
""" | |

import argparse
import asyncio
import base64
import io
import json
import os
import sys
import traceback

import cv2
import mss
import mss.tools
import PIL.Image
import pyaudio
from websockets.asyncio.client import connect

# Backports of asyncio.TaskGroup and ExceptionGroup for Python < 3.11.
if sys.version_info < (3, 11, 0):
    import taskgroup
    import exceptiongroup

    asyncio.TaskGroup = taskgroup.TaskGroup
    asyncio.ExceptionGroup = exceptiongroup.ExceptionGroup
    # Bind the bare name too, so the `except ExceptionGroup` below works.
    ExceptionGroup = exceptiongroup.ExceptionGroup

FORMAT = pyaudio.paInt16
CHANNELS = 1
SEND_SAMPLE_RATE = 16000  # mic audio is captured and sent at 16 kHz
RECEIVE_SAMPLE_RATE = 24000  # model audio is played back at 24 kHz
CHUNK_SIZE = 512

host = "generativelanguage.googleapis.com"
model = "gemini-2.0-flash-exp"

DEFAULT_MODE = "camera"

api_key = os.environ["GOOGLE_API_KEY"]
uri = f"wss://{host}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={api_key}"
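
# Wire protocol used below (JSON messages over the websocket):
#   1. The client first sends a one-off `setup` message naming the model.
#   2. It then streams `realtime_input` messages (audio and image chunks)
#      and `client_content` messages (text turns).
#   3. The server streams back `serverContent` messages carrying inline
#      audio data and `turnComplete` markers.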

class AudioLoop:
    def __init__(self, video_mode=DEFAULT_MODE):
        self.video_mode = video_mode
        self.audio_in_queue = None
        self.out_queue = None
        self.ws = None
        self.audio_stream = None

    async def startup(self):
        # Send the setup message and wait for the server's acknowledgement
        # before starting any other traffic.
        setup_msg = {"setup": {"model": f"models/{model}"}}
        await self.ws.send(json.dumps(setup_msg))
        raw_response = await self.ws.recv(decode=False)
        setup_response = json.loads(raw_response.decode("ascii"))

    async def send_text(self):
        # Read text from stdin and send it as a complete user turn.
        # Typing "q" ends the session.
        while True:
            text = await asyncio.to_thread(input, "message > ")
            if text.lower() == "q":
                break
            msg = {
                "client_content": {
                    "turn_complete": True,
                    "turns": [{"role": "user", "parts": [{"text": text}]}],
                }
            }
            await self.ws.send(json.dumps(msg))

    def _get_frame(self, cap):
        # Read a frame; returns None once the capture device stops producing.
        ret, frame = cap.read()
        if not ret:
            return None
        # OpenCV captures in BGR but PIL expects RGB; converting here avoids
        # a blue tint in the frames sent to the model.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = PIL.Image.fromarray(frame_rgb)
        # Downscale to keep the JPEG payload small.
        img.thumbnail([1024, 1024])

        image_io = io.BytesIO()
        img.save(image_io, format="jpeg")
        image_io.seek(0)

        mime_type = "image/jpeg"
        image_bytes = image_io.read()
        return {"mime_type": mime_type, "data": base64.b64encode(image_bytes).decode()}

    async def get_frames(self):
        # Opening the camera takes about a second and would block the whole
        # event loop, overflowing the audio pipeline, if not run in a thread.
        cap = await asyncio.to_thread(
            cv2.VideoCapture, 0
        )  # 0 is the default camera
        while True:
            frame = await asyncio.to_thread(self._get_frame, cap)
            if frame is None:
                break
            await asyncio.sleep(1.0)  # send roughly one frame per second

            msg = {"realtime_input": {"media_chunks": [frame]}}
            await self.out_queue.put(msg)

        # Release the VideoCapture object
        cap.release()

    def _get_screen(self):
        sct = mss.mss()
        monitor = sct.monitors[0]  # monitor 0 is all displays combined
        i = sct.grab(monitor)

        mime_type = "image/jpeg"
        # mss only emits PNG, so round-trip through PIL to re-encode as JPEG.
        png_bytes = mss.tools.to_png(i.rgb, i.size)
        img = PIL.Image.open(io.BytesIO(png_bytes))

        image_io = io.BytesIO()
        img.save(image_io, format="jpeg")
        image_io.seek(0)

        image_bytes = image_io.read()
        return {"mime_type": mime_type, "data": base64.b64encode(image_bytes).decode()}

    async def get_screen(self):
        while True:
            frame = await asyncio.to_thread(self._get_screen)
            if frame is None:
                break
            await asyncio.sleep(1.0)

            # Wrap the chunk in a list, matching the `media_chunks` schema
            # used by get_frames above.
            msg = {"realtime_input": {"media_chunks": [frame]}}
            await self.out_queue.put(msg)
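
    # A single sender task drains out_queue, so frames and mic audio are
    # serialized onto the websocket rather than interleaving mid-send.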
    async def send_realtime(self):
        while True:
            msg = await self.out_queue.get()
            await self.ws.send(json.dumps(msg))

    async def listen_audio(self):
        # Capture mono 16-bit PCM from the default microphone and queue it
        # for send_realtime.
        pya = pyaudio.PyAudio()
        mic_info = pya.get_default_input_device_info()
        self.audio_stream = pya.open(
            format=FORMAT,
            channels=CHANNELS,
            rate=SEND_SAMPLE_RATE,
            input=True,
            input_device_index=mic_info["index"],
            frames_per_buffer=CHUNK_SIZE,
        )
        while True:
            data = await asyncio.to_thread(self.audio_stream.read, CHUNK_SIZE)
            msg = {
                "realtime_input": {
                    "media_chunks": [
                        {
                            "data": base64.b64encode(data).decode(),
                            "mime_type": "audio/pcm",
                        }
                    ]
                }
            }
            await self.out_queue.put(msg)

    async def receive_audio(self):
        "Background task that reads from the websocket and queues PCM chunks for playback."
        async for raw_response in self.ws:
            # Other things could be returned here, but we'll ignore those for now.
            response = json.loads(raw_response.decode("ascii"))

            try:
                b64data = response["serverContent"]["modelTurn"]["parts"][0][
                    "inlineData"
                ]["data"]
            except KeyError:
                pass
            else:
                pcm_data = base64.b64decode(b64data)
                self.audio_in_queue.put_nowait(pcm_data)

            try:
                turn_complete = response["serverContent"]["turnComplete"]
            except KeyError:
                pass
            else:
                if turn_complete:
                    # If you interrupt the model, it sends a turn_complete.
                    # For interruptions to work, we need to empty the audio
                    # queue, because it may hold much more audio than has
                    # actually played yet.
                    print("\nEnd of turn")
                    while not self.audio_in_queue.empty():
                        self.audio_in_queue.get_nowait()

    async def play_audio(self):
        # Play the model's 24 kHz PCM audio on the default output device.
        pya = pyaudio.PyAudio()
        stream = pya.open(
            format=FORMAT, channels=CHANNELS, rate=RECEIVE_SAMPLE_RATE, output=True
        )
        while True:
            bytestream = await self.audio_in_queue.get()
            await asyncio.to_thread(stream.write, bytestream)

    async def run(self):
        """Open the websocket, perform the setup handshake, then run the
        send/receive tasks until the user types 'q' at the text prompt.
        """
        try:
            async with (
                await connect(
                    uri, additional_headers={"Content-Type": "application/json"}
                ) as ws,
                asyncio.TaskGroup() as tg,
            ):
                self.ws = ws
                await self.startup()

                self.audio_in_queue = asyncio.Queue()
                # Bounding out_queue applies backpressure: capture tasks pause
                # rather than queue messages faster than the socket can send.
                self.out_queue = asyncio.Queue(maxsize=5)

                send_text_task = tg.create_task(self.send_text())
                tg.create_task(self.send_realtime())
                tg.create_task(self.listen_audio())
                if self.video_mode == "camera":
                    tg.create_task(self.get_frames())
                elif self.video_mode == "screen":
                    tg.create_task(self.get_screen())
                tg.create_task(self.receive_audio())
                tg.create_task(self.play_audio())

                await send_text_task
                raise asyncio.CancelledError("User requested exit")
        except asyncio.CancelledError:
            pass
        except ExceptionGroup as EG:
            if self.audio_stream is not None:
                self.audio_stream.close()
            traceback.print_exception(EG)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        type=str,
        default=DEFAULT_MODE,
        help="pixels to stream from",
        choices=["camera", "screen", "none"],
    )
    args = parser.parse_args()
    main = AudioLoop(video_mode=args.mode)
    asyncio.run(main.run())