import asyncio
import base64
import collections
import io
import os
import time
from typing import AsyncGenerator, Literal

import gradio as gr
import numpy as np
import soundfile as sf
import webrtcvad
from dotenv import load_dotenv
from fastrtc import (
    AsyncStreamHandler,
    WebRTC,
    get_cloudflare_turn_credentials_async,
    wait_for_item,
)
from google import genai
from google.genai import types
from google.genai.types import (
    Content,
    LiveConnectConfig,
    Part,
    PrebuiltVoiceConfig,
    SpeechConfig,
    VoiceConfig,
)
from gradio.utils import get_space

load_dotenv()

# API keys and service endpoints are read from the environment (e.g. the .env
# file loaded above) instead of being hardcoded in the source.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
QDRANT_URL = os.getenv("QDRANT_URL")
WEAVIATE_URL = os.getenv("WEAVIATE_URL")
WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY")
DEEPINFRA_API_KEY = os.getenv("DEEPINFRA_API_KEY")
DEEPINFRA_BASE_URL = os.getenv("DEEPINFRA_BASE_URL", "https://api.deepinfra.com/v1/openai")
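# A minimal .env sketch (placeholder values, not real credentials); only
# GEMINI_API_KEY is required by the code below, the rest are optional:
#   GEMINI_API_KEY=your-gemini-api-key
#   TAVILY_API_KEY=your-tavily-api-key
#   OPENAI_API_KEY=your-openai-api-key
#   QDRANT_API_KEY=your-qdrant-api-key
#   QDRANT_URL=https://your-cluster.cloud.qdrant.io
#   WEAVIATE_URL=your-cluster.weaviate.cloud
#   WEAVIATE_API_KEY=your-weaviate-api-key
#   DEEPINFRA_API_KEY=your-deepinfra-api-key
#   DEEPINFRA_BASE_URL=https://api.deepinfra.com/v1/openai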

# helper functions
def encode_audio(data: np.ndarray) -> dict:
    """Encode audio data to send to the server."""
    return {
        "mime_type": "audio/pcm",
        "data": base64.b64encode(data.tobytes()).decode("UTF-8"),
    }


def encode_audio2(data: np.ndarray) -> bytes:
    """Encode audio data to send to the server as raw PCM bytes."""
    return data.tobytes()

def numpy_array_to_wav_bytes(audio_array, sample_rate=16000):
    """
    Convert a NumPy audio array to WAV bytes.

    Args:
        audio_array (np.ndarray): Audio signal (1D or 2D).
        sample_rate (int): Sample rate in Hz.

    Returns:
        bytes: WAV-formatted audio data.
    """
    buffer = io.BytesIO()
    sf.write(buffer, audio_array, sample_rate, format="WAV")
    buffer.seek(0)
    return buffer.read()
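
# Usage sketch (illustrative values only): wrap one second of silent 16 kHz
# audio as WAV bytes, the same format GeminiHandler.s2t() below sends for
# transcription.
#   silence = np.zeros(16000, dtype=np.int16)
#   wav_bytes = numpy_array_to_wav_bytes(silence, sample_rate=16000)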

# WebRTC handler class
class GeminiHandler(AsyncStreamHandler):
    """Handler for the Gemini API with chained (STT -> LLM -> TTS) latency calculation."""

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        prompt_dict: dict = {"prompt": "PHQ-9"},
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            input_sample_rate=16000,
        )
        self.input_queue: asyncio.Queue = asyncio.Queue()
        self.output_queue: asyncio.Queue = asyncio.Queue()
        self.quit: asyncio.Event = asyncio.Event()
        self.prompt_dict = prompt_dict
        # self.model = "gemini-2.5-flash-preview-tts"
        self.model = "gemini-2.0-flash-live-001"
        self.t2t_model = "gemini-2.0-flash"
        self.s2t_model = "gemini-2.0-flash"
        # --- VAD initialization ---
        self.vad = webrtcvad.Vad(3)  # aggressiveness 3 = most aggressive filtering
        self.VAD_RATE = 16000
        self.VAD_FRAME_MS = 20
        self.VAD_FRAME_SAMPLES = int(self.VAD_RATE * (self.VAD_FRAME_MS / 1000.0))  # 320 samples
        self.VAD_FRAME_BYTES = self.VAD_FRAME_SAMPLES * 2  # 16-bit PCM -> 640 bytes per frame
        padding_ms = 300
        self.vad_padding_frames = padding_ms // self.VAD_FRAME_MS
        self.vad_ring_buffer = collections.deque(maxlen=self.vad_padding_frames)
        self.vad_ratio = 0.9
        self.vad_triggered = False
        self.wav_data = bytearray()
        self.internal_buffer = bytearray()
        self.end_of_speech_time: float | None = None
        self.first_latency_calculated: bool = False

    def copy(self) -> "GeminiHandler":
        return GeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            prompt_dict=self.prompt_dict,
        )
    def t2t(self, text: str) -> str:
        """Send user text to the chat model and return its reply."""
        print(f"Sending text to Gemini: {text}")
        response = self.chat.send_message(text)
        print(f"Received response from Gemini: {response.text}")
        return response.text

    def s2t(self, audio) -> str:
        """Transcribe WAV audio bytes with the speech-to-text model."""
        response = self.s2t_client.models.generate_content(
            model=self.s2t_model,
            contents=[
                types.Part.from_bytes(data=audio, mime_type="audio/wav"),
                "Generate a transcript of the speech.",
            ],
        )
        return response.text
    async def start_up(self):
        # Flag controlling whether a text-to-text step runs in the middle of the chain.
        self.t2t_bool = False
        self.sys_prompt = None
        self.t2t_client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
        self.s2t_client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))  # , http_options={"api_version": "v1alpha"})
        if self.sys_prompt is not None:
            chat_config = types.GenerateContentConfig(system_instruction=self.sys_prompt)
        else:
            chat_config = types.GenerateContentConfig(system_instruction="You are a helpful assistant.")
        self.chat = self.t2t_client.chats.create(model=self.t2t_model, config=chat_config)
        self.t2s_client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
        voice_name = "Puck"
        if self.t2t_bool:
            sys_instruction = """You are Wisal, an AI assistant developed by Compumacy AI, and a knowledgeable Autism specialist.
            Your sole purpose is to provide helpful, respectful, and easy-to-understand answers about Autism Spectrum Disorder (ASD).
            Always be clear, non-judgmental, and supportive."""
        else:
            sys_instruction = self.sys_prompt
        if sys_instruction is not None:
            config = LiveConnectConfig(
                response_modalities=["AUDIO"],
                speech_config=SpeechConfig(
                    voice_config=VoiceConfig(
                        prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=voice_name)
                    )
                ),
                system_instruction=Content(parts=[Part.from_text(text=sys_instruction)]),
            )
        else:
            config = LiveConnectConfig(
                response_modalities=["AUDIO"],
                speech_config=SpeechConfig(
                    voice_config=VoiceConfig(
                        prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=voice_name)
                    )
                ),
            )
        async with self.t2s_client.aio.live.connect(model=self.model, config=config) as session:
            async for text_from_user in self.stream():
                print("--------------------------------------------")
                print(f"Received text from user and reading aloud: {text_from_user}")
                print("--------------------------------------------")
                if text_from_user and text_from_user.strip():
                    if self.t2t_bool:
                        prompt = f"""
                        You are Wisal, an AI assistant developed by Compumacy AI, and a knowledgeable Autism specialist.
                        Your sole purpose is to provide helpful, respectful, and easy-to-understand answers about Autism Spectrum Disorder (ASD).
                        Always be clear, non-judgmental, and supportive.
                        {text_from_user}
                        """
                    else:
                        prompt = text_from_user
                    await session.send_client_content(
                        turns=types.Content(
                            role="user", parts=[types.Part(text=prompt)]
                        )
                    )
                    async for resp_chunk in session.receive():
                        if resp_chunk.data:
                            array = np.frombuffer(resp_chunk.data, dtype=np.int16)
                            self.output_queue.put_nowait((self.output_sample_rate, array))
    async def stream(self) -> AsyncGenerator[str, None]:
        while not self.quit.is_set():
            try:
                # Wait briefly for the next text message to be converted to speech,
                # so the quit flag is rechecked periodically.
                text_to_speak = await asyncio.wait_for(self.input_queue.get(), timeout=1.0)
                yield text_to_speak
            except (asyncio.TimeoutError, TimeoutError):
                pass
    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        sr, array = frame
        audio_bytes = array.tobytes()
        self.internal_buffer.extend(audio_bytes)
        # Process the incoming audio in fixed-size VAD frames.
        while len(self.internal_buffer) >= self.VAD_FRAME_BYTES:
            vad_frame = bytes(self.internal_buffer[:self.VAD_FRAME_BYTES])
            self.internal_buffer = self.internal_buffer[self.VAD_FRAME_BYTES:]
            is_speech = self.vad.is_speech(vad_frame, self.VAD_RATE)
            if not self.vad_triggered:
                # Not yet recording: start once enough of the ring buffer is voiced.
                self.vad_ring_buffer.append((vad_frame, is_speech))
                num_voiced = len([f for f, speech in self.vad_ring_buffer if speech])
                if num_voiced > self.vad_ratio * self.vad_ring_buffer.maxlen:
                    print("Speech detected, starting to record...")
                    self.vad_triggered = True
                    for f, s in self.vad_ring_buffer:
                        self.wav_data.extend(f)
                    self.vad_ring_buffer.clear()
            else:
                # Recording: stop once enough of the ring buffer is unvoiced.
                self.wav_data.extend(vad_frame)
                self.vad_ring_buffer.append((vad_frame, is_speech))
                num_unvoiced = len([f for f, speech in self.vad_ring_buffer if not speech])
                if num_unvoiced > self.vad_ratio * self.vad_ring_buffer.maxlen:
                    print("End of speech detected.")
                    self.end_of_speech_time = time.monotonic()
                    self.vad_triggered = False
                    # Run the chain: speech-to-text, optional text-to-text, then queue for TTS.
                    full_utterance_np = np.frombuffer(self.wav_data, dtype=np.int16)
                    audio_input_wav = numpy_array_to_wav_bytes(full_utterance_np, sr)
                    text_input = self.s2t(audio_input_wav)
                    if text_input and text_input.strip():
                        if self.t2t_bool:
                            text_message = self.t2t(text_input)
                        else:
                            text_message = text_input
                        self.input_queue.put_nowait(text_message)
                    else:
                        print("STT returned empty transcript, skipping.")
                    self.vad_ring_buffer.clear()
                    self.wav_data = bytearray()
    async def emit(self) -> tuple[int, np.ndarray] | None:
        item = await wait_for_item(self.output_queue)
        # Chained latency: time from the detected end of user speech to the
        # first synthesized audio chunk being emitted.
        if (
            item is not None
            and self.end_of_speech_time is not None
            and not self.first_latency_calculated
        ):
            latency = time.monotonic() - self.end_of_speech_time
            print(f"Chained latency (end of speech -> first audio out): {latency:.2f}s")
            self.first_latency_calculated = True
        return item

    def shutdown(self) -> None:
        self.quit.set()

with gr.Blocks() as demo:
    gr.Markdown("# Gemini Chained Speech-to-Speech Demo")
    # For the audio modality.
    # with gr.Row(visible=(modality_selector.value == "audio")) as row2:
    with gr.Row() as row2:
        with gr.Column():  # Optional, can be removed if not needed
            webrtc2 = WebRTC(
                label="Audio Chat",
                modality="audio",
                mode="send-receive",
                elem_id="audio-source",
                rtc_configuration=get_cloudflare_turn_credentials_async,
                icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
                pulse_color="rgb(255, 255, 255)",
                icon_button_color="rgb(255, 255, 255)",
            )
        # Stream audio through webrtc2 (corrected from webrtc).
        webrtc2.stream(
            GeminiHandler(),
            inputs=[webrtc2],   # was webrtc
            outputs=[webrtc2],  # was webrtc
            time_limit=180 if get_space() else None,
            concurrency_limit=2 if get_space() else None,
        )

if __name__ == "__main__":
    demo.launch(server_port=7860)
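
# --- Alternative deployment sketch (assumption, not wired up above) ---------
# A minimal sketch, assuming fastrtc's Stream/FastAPI mounting API, for serving
# the same handler without the Gradio Blocks UI; the names below (app, stream)
# are illustrative only:
#
#   from fastapi import FastAPI
#   from fastrtc import Stream
#
#   stream = Stream(
#       handler=GeminiHandler(),
#       modality="audio",
#       mode="send-receive",
#       rtc_configuration=get_cloudflare_turn_credentials_async,
#   )
#   app = FastAPI()
#   stream.mount(app)  # exposes the WebRTC signalling endpoints on the FastAPI app
#   # run with: uvicorn app:app --host 0.0.0.0 --port 7860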