# wisalQA_P1 / Live_audio.py
import asyncio
import base64
import collections
import io
import os
import time
from typing import AsyncGenerator, Literal

import gradio as gr
import numpy as np
import soundfile as sf
import webrtcvad
from dotenv import load_dotenv
from fastrtc import (
    AsyncStreamHandler,
    WebRTC,
    get_cloudflare_turn_credentials_async,
    wait_for_item,
)
from google import genai
from google.genai import types
from google.genai.types import (
    Content,
    LiveConnectConfig,
    Part,
    PrebuiltVoiceConfig,
    SpeechConfig,
    VoiceConfig,
)
from gradio.utils import get_space

load_dotenv()
# Configuration: all secrets are read from the environment (populated via .env).
# Hard-coding API keys in source is a security bug, and the clients below
# already read GEMINI_API_KEY from the environment.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
QDRANT_URL = os.getenv("QDRANT_URL")
WEAVIATE_URL = os.getenv("WEAVIATE_URL")
WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY")
DEEPINFRA_API_KEY = os.getenv("DEEPINFRA_API_KEY")
DEEPINFRA_BASE_URL = os.getenv("DEEPINFRA_BASE_URL", "https://api.deepinfra.com/v1/openai")

# Helper functions
def encode_audio(data: np.ndarray) -> dict:
    """Encode audio data as base64 PCM for sending to the server."""
    return {
        "mime_type": "audio/pcm",
        "data": base64.b64encode(data.tobytes()).decode("UTF-8"),
    }


def encode_audio2(data: np.ndarray) -> bytes:
    """Encode audio data as raw PCM bytes."""
    return data.tobytes()
def numpy_array_to_wav_bytes(audio_array, sample_rate=16000):
    """
    Convert a NumPy audio array to WAV bytes.

    Args:
        audio_array (np.ndarray): Audio signal (1D or 2D).
        sample_rate (int): Sample rate in Hz.

    Returns:
        bytes: WAV-formatted audio data.
    """
    buffer = io.BytesIO()
    sf.write(buffer, audio_array, sample_rate, format="WAV")
    buffer.seek(0)
    return buffer.read()
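# Usage sketch (hypothetical values): one second of a 440 Hz tone at 16 kHz
# converted into WAV bytes suitable for the transcription request below.
#   tone = (np.sin(2 * np.pi * 440 * np.arange(16000) / 16000) * 32767).astype(np.int16)
#   wav_bytes = numpy_array_to_wav_bytes(tone, sample_rate=16000)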
# WebRTC handler class
class GeminiHandler(AsyncStreamHandler):
    """Handler for the Gemini API with a chained (STT -> optional LLM -> TTS) latency calculation."""

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        prompt_dict: dict | None = None,
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            input_sample_rate=16000,
        )
        self.input_queue: asyncio.Queue = asyncio.Queue()
        self.output_queue: asyncio.Queue = asyncio.Queue()
        self.quit: asyncio.Event = asyncio.Event()
        # Avoid a mutable default argument; fall back to the PHQ-9 prompt.
        self.prompt_dict = prompt_dict if prompt_dict is not None else {"prompt": "PHQ-9"}
# self.model = "gemini-2.5-flash-preview-tts"
self.model = "gemini-2.0-flash-live-001"
self.t2t_model = "gemini-2.0-flash"
self.s2t_model = "gemini-2.0-flash"
        # --- VAD initialization ---
        # Aggressiveness mode 3 is webrtcvad's most aggressive non-speech filter.
        self.vad = webrtcvad.Vad(3)
        self.VAD_RATE = 16000
        self.VAD_FRAME_MS = 20
        self.VAD_FRAME_SAMPLES = int(self.VAD_RATE * (self.VAD_FRAME_MS / 1000.0))
        self.VAD_FRAME_BYTES = self.VAD_FRAME_SAMPLES * 2  # 16-bit mono samples
        padding_ms = 300
        self.vad_padding_frames = padding_ms // self.VAD_FRAME_MS
        self.vad_ring_buffer = collections.deque(maxlen=self.vad_padding_frames)
        self.vad_ratio = 0.9
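        # With 20 ms frames and 300 ms of padding, the ring buffer holds 15
        # frames; the 0.9 ratio therefore requires at least 14 of them to agree
        # before the trigger state flips in either direction.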
self.vad_triggered = False
self.wav_data = bytearray()
self.internal_buffer = bytearray()
self.end_of_speech_time: float | None = None
self.first_latency_calculated: bool = False
def copy(self) -> "GeminiHandler":
return GeminiHandler(
expected_layout="mono",
output_sample_rate=self.output_sample_rate,
prompt_dict=self.prompt_dict,
)
    def t2t(self, text: str) -> str:
        """Send user text to the chat model and return its reply."""
        print(f"Sending text to Gemini: {text}")
        response = self.chat.send_message(text)
        print(f"Received response from Gemini: {response.text}")
        return response.text

    def s2t(self, audio: bytes) -> str:
        """Transcribe WAV audio bytes with the speech-to-text model."""
        response = self.s2t_client.models.generate_content(
            model=self.s2t_model,
            contents=[
                types.Part.from_bytes(data=audio, mime_type="audio/wav"),
                "Generate a transcript of the speech.",
            ],
        )
        return response.text
    async def start_up(self):
        # Whether to run a text-to-text step in the middle of the chain.
        self.t2t_bool = False
        self.sys_prompt = None
        self.t2t_client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
        self.s2t_client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
        if self.sys_prompt is not None:
            chat_config = types.GenerateContentConfig(system_instruction=self.sys_prompt)
        else:
            chat_config = types.GenerateContentConfig(system_instruction="You are a helpful assistant.")
        self.chat = self.t2t_client.chats.create(model=self.t2t_model, config=chat_config)
        self.t2s_client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
        voice_name = "Puck"
        if self.t2t_bool:
            sys_instruction = """You are Wisal, an AI assistant developed by Compumacy AI, and a knowledgeable Autism specialist.
Your sole purpose is to provide helpful, respectful, and easy-to-understand answers about Autism Spectrum Disorder (ASD).
Always be clear, non-judgmental, and supportive."""
        else:
            sys_instruction = self.sys_prompt
        # Build the live-session config once; only the optional system
        # instruction differs between the two branches.
        config_kwargs: dict = dict(
            response_modalities=["AUDIO"],
            speech_config=SpeechConfig(
                voice_config=VoiceConfig(
                    prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=voice_name)
                )
            ),
        )
        if sys_instruction is not None:
            config_kwargs["system_instruction"] = Content(
                parts=[Part.from_text(text=sys_instruction)]
            )
        config = LiveConnectConfig(**config_kwargs)
async with self.t2s_client.aio.live.connect(model=self.model, config=config) as session:
async for text_from_user in self.stream():
print("--------------------------------------------")
print(f"Received text from user and reading aloud: {text_from_user}")
print("--------------------------------------------")
                if text_from_user and text_from_user.strip():
                    if self.t2t_bool:
                        prompt = f"""
You are Wisal, an AI assistant developed by Compumacy AI, and a knowledgeable Autism specialist.
Your sole purpose is to provide helpful, respectful, and easy-to-understand answers about Autism Spectrum Disorder (ASD).
Always be clear, non-judgmental, and supportive.
{text_from_user}
"""
                    else:
                        prompt = text_from_user
                    await session.send_client_content(
                        turns=types.Content(role="user", parts=[types.Part(text=prompt)])
                    )
async for resp_chunk in session.receive():
if resp_chunk.data:
array = np.frombuffer(resp_chunk.data, dtype=np.int16)
self.output_queue.put_nowait((self.output_sample_rate, array))
    async def stream(self) -> AsyncGenerator[str, None]:
        while not self.quit.is_set():
            try:
                # Wait for the next text message to be converted to speech;
                # time out periodically so shutdown via self.quit is honored.
                text_to_speak = await asyncio.wait_for(self.input_queue.get(), timeout=1.0)
                yield text_to_speak
            except (asyncio.TimeoutError, TimeoutError):
                pass
    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        sr, array = frame
        audio_bytes = array.tobytes()
        self.internal_buffer.extend(audio_bytes)
        # Feed the VAD in fixed-size 20 ms frames.
        while len(self.internal_buffer) >= self.VAD_FRAME_BYTES:
            vad_frame = self.internal_buffer[: self.VAD_FRAME_BYTES]
            self.internal_buffer = self.internal_buffer[self.VAD_FRAME_BYTES :]
            is_speech = self.vad.is_speech(vad_frame, self.VAD_RATE)
            if not self.vad_triggered:
                # Not yet recording: trigger once most of the padding window is
                # voiced, and prepend the buffered frames so the start of the
                # utterance is not clipped.
                self.vad_ring_buffer.append((vad_frame, is_speech))
                num_voiced = len([f for f, speech in self.vad_ring_buffer if speech])
                if num_voiced > self.vad_ratio * self.vad_ring_buffer.maxlen:
                    print("Speech detected, starting to record...")
                    self.vad_triggered = True
                    for f, s in self.vad_ring_buffer:
                        self.wav_data.extend(f)
                    self.vad_ring_buffer.clear()
            else:
                # Recording: accumulate audio, detrigger once most of the
                # padding window is unvoiced, then run the STT (+ optional T2T) chain.
                self.wav_data.extend(vad_frame)
                self.vad_ring_buffer.append((vad_frame, is_speech))
                num_unvoiced = len([f for f, speech in self.vad_ring_buffer if not speech])
                if num_unvoiced > self.vad_ratio * self.vad_ring_buffer.maxlen:
                    print("End of speech detected.")
                    self.end_of_speech_time = time.monotonic()
                    self.vad_triggered = False
                    full_utterance_np = np.frombuffer(self.wav_data, dtype=np.int16)
                    audio_input_wav = numpy_array_to_wav_bytes(full_utterance_np, sr)
                    text_input = self.s2t(audio_input_wav)
                    if text_input and text_input.strip():
                        if self.t2t_bool:
                            text_message = self.t2t(text_input)
                        else:
                            text_message = text_input
                        self.input_queue.put_nowait(text_message)
                    else:
                        print("STT returned an empty transcript, skipping.")
                    self.vad_ring_buffer.clear()
                    self.wav_data = bytearray()
async def emit(self) -> tuple[int, np.ndarray] | None:
return await wait_for_item(self.output_queue)
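    # Hypothetical sketch: the otherwise-unused end_of_speech_time and
    # first_latency_calculated fields suggest first-response latency was meant
    # to be measured when the first audio chunk is emitted, e.g. inside emit():
    #   if self.end_of_speech_time and not self.first_latency_calculated:
    #       print(f"Chained latency: {time.monotonic() - self.end_of_speech_time:.2f}s")
    #       self.first_latency_calculated = True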
def shutdown(self) -> None:
self.quit.set()
with gr.Blocks() as demo:
    gr.Markdown("# Gemini Chained Speech-to-Speech Demo")
    # Audio modality
    with gr.Row() as row2:
        with gr.Column():
            webrtc2 = WebRTC(
                label="Audio Chat",
                modality="audio",
                mode="send-receive",
                elem_id="audio-source",
                rtc_configuration=get_cloudflare_turn_credentials_async,
                icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
                pulse_color="rgb(255, 255, 255)",
                icon_button_color="rgb(255, 255, 255)",
            )
            webrtc2.stream(
                GeminiHandler(),
                inputs=[webrtc2],
                outputs=[webrtc2],
                time_limit=180 if get_space() else None,
                concurrency_limit=2 if get_space() else None,
            )
if __name__ == "__main__":
demo.launch(server_port=7860)
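# To run locally (assumes GEMINI_API_KEY is available via the environment or .env):
#   python Live_audio.py
# then open http://localhost:7860 in a browser.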