afouda committed (verified)
Commit a059ad0 · 1 Parent(s): e914f90

Upload Live_audio.py

Files changed (1)
  1. Live_audio.py +348 -0
Live_audio.py ADDED
import asyncio
import base64
import collections
import io
import json
import os
import pathlib
import time
from io import BytesIO
from typing import AsyncGenerator, Literal

import gradio as gr
import numpy as np
import soundfile as sf
import webrtcvad
import websockets
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastrtc import (
    AsyncAudioVideoStreamHandler,
    AsyncStreamHandler,
    Stream,
    WebRTC,
    get_cloudflare_turn_credentials_async,
    wait_for_item,
)
from google import genai
from google.genai import types
from google.genai.types import (
    Content,
    LiveConnectConfig,
    Part,
    PrebuiltVoiceConfig,
    SpeechConfig,
    VoiceConfig,
)
from gradio.utils import get_space
from PIL import Image
from pydantic import BaseModel
from pydub import AudioSegment

load_dotenv()
# Configuration: secrets and endpoints are read from the environment (.env)
# instead of being hard-coded in the source.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
QDRANT_URL = os.getenv(
    "QDRANT_URL",
    "https://6a3aade6-e8ad-4a6c-a579-21f5af90b7e8.us-east4-0.gcp.cloud.qdrant.io",
)
WEAVIATE_URL = os.getenv(
    "WEAVIATE_URL", "yorcqe2sqswhcaivxvt9a.c0.us-west3.gcp.weaviate.cloud"
)
WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY")
DEEPINFRA_API_KEY = os.getenv("DEEPINFRA_API_KEY")
DEEPINFRA_BASE_URL = os.getenv("DEEPINFRA_BASE_URL", "https://api.deepinfra.com/v1/openai")
# helper functions
def encode_audio(data: np.ndarray) -> dict:
    """Encode audio data to send to the server."""
    return {
        "mime_type": "audio/pcm",
        "data": base64.b64encode(data.tobytes()).decode("UTF-8"),
    }


def encode_audio2(data: np.ndarray) -> bytes:
    """Encode audio data as raw PCM bytes."""
    return data.tobytes()


def numpy_array_to_wav_bytes(audio_array, sample_rate=16000):
    """
    Convert a NumPy audio array to WAV bytes.

    Args:
        audio_array (np.ndarray): Audio signal (1D or 2D).
        sample_rate (int): Sample rate in Hz.

    Returns:
        bytes: WAV-formatted audio data.
    """
    buffer = io.BytesIO()
    sf.write(buffer, audio_array, sample_rate, format='WAV')
    buffer.seek(0)
    return buffer.read()
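

# A minimal, hypothetical smoke test for the audio helpers above (illustrative
# only; _helper_smoke_test is not part of the app flow and is never called).
# It synthesizes one second of a 440 Hz tone, converts it to WAV bytes with
# numpy_array_to_wav_bytes(), and builds the PCM payload encode_audio() returns.
def _helper_smoke_test() -> None:
    sr = 16000
    t = np.linspace(0, 1.0, sr, endpoint=False)
    tone = (0.3 * np.sin(2 * np.pi * 440.0 * t) * 32767).astype(np.int16)
    wav_bytes = numpy_array_to_wav_bytes(tone, sample_rate=sr)
    payload = encode_audio(tone)
    print(f"WAV size: {len(wav_bytes)} bytes, payload mime_type: {payload['mime_type']}")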
# webrtc handler class
class GeminiHandler(AsyncStreamHandler):
    """Handler for the Gemini API with chained latency calculation."""

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        prompt_dict: dict = {"prompt": "PHQ-9"},
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            input_sample_rate=16000,
        )
        self.input_queue: asyncio.Queue = asyncio.Queue()
        self.output_queue: asyncio.Queue = asyncio.Queue()
        self.quit: asyncio.Event = asyncio.Event()
        self.prompt_dict = prompt_dict
        # self.model = "gemini-2.5-flash-preview-tts"
        self.model = "gemini-2.0-flash-live-001"
        self.t2t_model = "gemini-2.0-flash"
        self.s2t_model = "gemini-2.0-flash"

        # --- VAD Initialization ---
        self.vad = webrtcvad.Vad(3)
        self.VAD_RATE = 16000
        self.VAD_FRAME_MS = 20
        self.VAD_FRAME_SAMPLES = int(self.VAD_RATE * (self.VAD_FRAME_MS / 1000.0))
        self.VAD_FRAME_BYTES = self.VAD_FRAME_SAMPLES * 2
        padding_ms = 300
        self.vad_padding_frames = padding_ms // self.VAD_FRAME_MS
        self.vad_ring_buffer = collections.deque(maxlen=self.vad_padding_frames)
        self.vad_ratio = 0.9
        self.vad_triggered = False
        self.wav_data = bytearray()
        self.internal_buffer = bytearray()

        self.end_of_speech_time: float | None = None
        self.first_latency_calculated: bool = False
    def copy(self) -> "GeminiHandler":
        return GeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            prompt_dict=self.prompt_dict,
        )

    def t2t(self, text: str) -> str:
        """Text-to-text turn through the Gemini chat session."""
        print(f"Sending text to Gemini: {text}")
        response = self.chat.send_message(text)
        print(f"Received response from Gemini: {response.text}")
        return response.text

    def s2t(self, audio) -> str:
        """Speech-to-text: transcribe a WAV payload with Gemini."""
        response = self.s2t_client.models.generate_content(
            model=self.s2t_model,
            contents=[
                types.Part.from_bytes(data=audio, mime_type='audio/wav'),
                'Generate a transcript of the speech.',
            ]
        )
        return response.text
    async def start_up(self):
        # Flag for whether text-to-text is used in the middle of the chain or not.
        self.t2t_bool = False
        self.sys_prompt = None

        self.t2t_client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
        self.s2t_client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))  # , http_options={"api_version": "v1alpha"}
        if self.sys_prompt is not None:
            chat_config = types.GenerateContentConfig(system_instruction=self.sys_prompt)
        else:
            chat_config = types.GenerateContentConfig(system_instruction="You are a helpful assistant.")
        self.chat = self.t2t_client.chats.create(model=self.t2t_model, config=chat_config)
        self.t2s_client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))

        voice_name = "Puck"
        if self.t2t_bool:
            sys_instruction = """You are Wisal, an AI assistant developed by Compumacy AI and a knowledgeable Autism specialist.
            Your sole purpose is to provide helpful, respectful, and easy-to-understand answers about Autism Spectrum Disorder (ASD).
            Always be clear, non-judgmental, and supportive."""
        else:
            sys_instruction = self.sys_prompt

        if sys_instruction is not None:
            config = LiveConnectConfig(
                response_modalities=["AUDIO"],
                speech_config=SpeechConfig(
                    voice_config=VoiceConfig(
                        prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=voice_name)
                    )
                ),
                system_instruction=Content(parts=[Part.from_text(text=sys_instruction)]),
            )
        else:
            config = LiveConnectConfig(
                response_modalities=["AUDIO"],
                speech_config=SpeechConfig(
                    voice_config=VoiceConfig(
                        prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=voice_name)
                    )
                ),
            )

        async with self.t2s_client.aio.live.connect(model=self.model, config=config) as session:
            async for text_from_user in self.stream():
                print("--------------------------------------------")
                print(f"Received text from user and reading aloud: {text_from_user}")
                print("--------------------------------------------")
                if text_from_user and text_from_user.strip():
                    if self.t2t_bool:
                        prompt = f"""
                        You are Wisal, an AI assistant developed by Compumacy AI and a knowledgeable Autism specialist.
                        Your sole purpose is to provide helpful, respectful, and easy-to-understand answers about Autism Spectrum Disorder (ASD).
                        Always be clear, non-judgmental, and supportive.

                        {text_from_user}
                        """
                    else:
                        prompt = text_from_user
                    await session.send_client_content(
                        turns=types.Content(
                            role='user', parts=[types.Part(text=prompt)]))
                    async for resp_chunk in session.receive():
                        if resp_chunk.data:
                            array = np.frombuffer(resp_chunk.data, dtype=np.int16)
                            self.output_queue.put_nowait((self.output_sample_rate, array))
    async def stream(self) -> AsyncGenerator[bytes, None]:
        while not self.quit.is_set():
            try:
                # Get the text message to be converted to speech
                text_to_speak = await self.input_queue.get()
                yield text_to_speak
            except (asyncio.TimeoutError, TimeoutError):
                pass
    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        sr, array = frame
        audio_bytes = array.tobytes()
        self.internal_buffer.extend(audio_bytes)

        # Feed fixed-size 20 ms frames to webrtcvad.
        while len(self.internal_buffer) >= self.VAD_FRAME_BYTES:
            vad_frame = self.internal_buffer[:self.VAD_FRAME_BYTES]
            self.internal_buffer = self.internal_buffer[self.VAD_FRAME_BYTES:]
            is_speech = self.vad.is_speech(vad_frame, self.VAD_RATE)

            if not self.vad_triggered:
                self.vad_ring_buffer.append((vad_frame, is_speech))
                num_voiced = len([f for f, speech in self.vad_ring_buffer if speech])
                if num_voiced > self.vad_ratio * self.vad_ring_buffer.maxlen:
                    print("Speech detected, starting to record...")
                    self.vad_triggered = True
                    for f, s in self.vad_ring_buffer:
                        self.wav_data.extend(f)
                    self.vad_ring_buffer.clear()
            else:
                self.wav_data.extend(vad_frame)
                self.vad_ring_buffer.append((vad_frame, is_speech))
                num_unvoiced = len([f for f, speech in self.vad_ring_buffer if not speech])
                if num_unvoiced > self.vad_ratio * self.vad_ring_buffer.maxlen:
                    print("End of speech detected.")

                    self.end_of_speech_time = time.monotonic()

                    self.vad_triggered = False
                    full_utterance_np = np.frombuffer(self.wav_data, dtype=np.int16)
                    audio_input_wav = numpy_array_to_wav_bytes(full_utterance_np, sr)

                    # Chained pipeline: speech-to-text, optional text-to-text, then queue for TTS.
                    text_input = self.s2t(audio_input_wav)
                    if text_input and text_input.strip():
                        if self.t2t_bool:
                            text_message = self.t2t(text_input)
                        else:
                            text_message = text_input
                        self.input_queue.put_nowait(text_message)
                    else:
                        print("STT returned empty transcript, skipping.")

                    self.vad_ring_buffer.clear()
                    self.wav_data = bytearray()
    async def emit(self) -> tuple[int, np.ndarray] | None:
        return await wait_for_item(self.output_queue)

    def shutdown(self) -> None:
        self.quit.set()
with gr.Blocks() as demo:
    gr.Markdown("# Gemini Chained Speech-to-Speech Demo")

    # for audio modality
    # with gr.Row(visible=(modality_selector.value == "audio")) as row2:
    with gr.Row() as row2:
        with gr.Column():  # Optional, can be removed if not needed
            webrtc2 = WebRTC(
                label="Audio Chat",
                modality="audio",
                mode="send-receive",
                elem_id="audio-source",
                rtc_configuration=get_cloudflare_turn_credentials_async,
                icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
                pulse_color="rgb(255, 255, 255)",
                icon_button_color="rgb(255, 255, 255)",
            )
            # Stream the handler's audio in and out through the same WebRTC component.
            webrtc2.stream(
                GeminiHandler(),
                inputs=[webrtc2],
                outputs=[webrtc2],
                time_limit=180 if get_space() else None,
                concurrency_limit=2 if get_space() else None,
            )

if __name__ == "__main__":
    demo.launch(server_port=7860)