akhaliq HF Staff committed on
Commit
16e5f81
·
verified ·
1 Parent(s): 4078b6c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -0
app.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gradio_webrtc import WebRTC, StreamHandler
3
+ import websockets.sync.client
4
+ import numpy as np
5
+ import json
6
+ import base64
7
+ import os
8
+ from dotenv import load_dotenv
9
+
10
# Load environment variables (the error below points users at a .env file)
load_dotenv()

# Resolve the Google API key once at import time and fail fast when it is
# missing or empty, so the app never starts without credentials.
API_KEY = os.environ.get('GOOGLE_API_KEY')
if not API_KEY:
    raise ValueError("GOOGLE_API_KEY not found in environment variables. Please set it in your .env file.")
17
+
18
class GeminiHandler(StreamHandler):
    """Stream handler that relays audio between WebRTC and the Gemini WebSocket API.

    ``receive`` uploads each incoming audio frame to the server as a
    base64-encoded PCM chunk; ``emit`` returns the model's audio reply in
    frames of ``output_frame_size`` samples.
    """

    # Rate requested from the base handler for incoming frames; also the
    # fallback advertised to the server when no explicit rate is supplied.
    INPUT_SAMPLE_RATE = 24000

    def __init__(self,
                 expected_layout="mono",
                 output_sample_rate=24000,
                 output_frame_size=480) -> None:
        """Store connection settings; the socket itself is opened lazily.

        Args:
            expected_layout: Channel layout forwarded to ``StreamHandler``.
            output_sample_rate: Sample rate reported with every emitted frame.
            output_frame_size: Number of samples in each emitted frame.
        """
        self.ws = None
        self.HOST = 'generativelanguage.googleapis.com'
        self.API_KEY = API_KEY  # Use the validated API key
        self.MODEL = 'models/gemini-2.0-flash-exp'
        self.ws_url = f'wss://{self.HOST}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={self.API_KEY}'
        # Decoded int16 samples received from the server but not yet emitted.
        self.all_output_data = None
        # Lazily created by emit(); None means "no generator yet".
        self._generator = None
        super().__init__(expected_layout, output_sample_rate, output_frame_size,
                         input_sample_rate=self.INPUT_SAMPLE_RATE)

    def copy(self):
        """Return a fresh, unconnected handler with the same audio settings."""
        return GeminiHandler(
            expected_layout=self.expected_layout,
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size
        )

    def encode_audio_input(self, data, sample_rate=None):
        """Wrap raw audio samples in a ``realtimeInput`` message.

        Args:
            data: Array whose ``tobytes()`` is the PCM payload
                (presumably int16 samples -- confirm against the caller).
            sample_rate: Rate of ``data`` in Hz. Defaults to
                ``INPUT_SAMPLE_RATE`` so existing one-argument calls keep
                working.

        Returns:
            A JSON-serialisable dict holding one base64-encoded media chunk.
        """
        if sample_rate is None:
            sample_rate = self.INPUT_SAMPLE_RATE
        encoded = base64.b64encode(data.tobytes()).decode('UTF-8')
        return {
            'realtimeInput': {
                'mediaChunks': [{
                    # Describe the *uploaded* audio, not the playback rate.
                    'mimeType': f'audio/pcm;rate={sample_rate}',
                    'data': encoded,
                }],
            },
        }

    def _connect(self):
        """Open the WebSocket, complete the setup handshake and return it.

        The socket is returned only after the setup response has been read,
        so no other method can call ``recv()`` on it while the handshake is
        still in flight.
        """
        ws = websockets.sync.client.connect(self.ws_url)
        try:
            initial_request = {
                'setup': {
                    'model': self.MODEL,
                }
            }
            ws.send(json.dumps(initial_request))
            setup_response = json.loads(ws.recv())
        except Exception:
            # Do not leak a half-initialised connection.
            ws.close()
            raise
        print(f"Setup response: {setup_response}")
        return ws

    def _close_ws(self) -> None:
        """Close the current socket, if any, and forget it."""
        ws, self.ws = self.ws, None
        if ws is not None:
            ws.close()

    def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Send one incoming audio frame to the server, connecting on first use.

        Args:
            frame: ``(sample_rate, samples)`` pair for the incoming audio.

        Any failure is logged and the connection is dropped; the next call
        opens a new one.
        """
        try:
            ws = self.ws
            if ws is None:
                ws = self._connect()
                self.ws = ws

            sample_rate, array = frame
            array = array.squeeze()
            # Send audio data
            audio_message = self.encode_audio_input(array, sample_rate)
            ws.send(json.dumps(audio_message))
        except Exception as e:
            print(f"Error in receive: {str(e)}")
            self._close_ws()

    def generator(self):
        """Yield ``(sample_rate, frame)`` tuples decoded from server messages.

        Yields ``None`` while there is no connection. Audio from
        ``serverContent`` messages is appended to ``all_output_data`` and
        drained in chunks of ``output_frame_size`` samples, each reshaped to
        ``(1, output_frame_size)``.
        """
        while True:
            ws = self.ws
            if not ws:
                yield None
                continue

            message = ws.recv()
            msg = json.loads(message)

            if 'serverContent' not in msg:
                continue

            content = msg['serverContent'].get('modelTurn', {})
            for part in content.get('parts', []):
                data = part.get('inlineData', {}).get('data', '')
                if not data:
                    continue

                audio_data = base64.b64decode(data)
                audio_array = np.frombuffer(audio_data, dtype=np.int16)
                if self.all_output_data is None:
                    self.all_output_data = audio_array
                else:
                    self.all_output_data = np.concatenate((self.all_output_data, audio_array))

                # Drain only after a chunk was appended, so the buffer is
                # guaranteed to be an array here (never None).
                while self.all_output_data.shape[-1] >= self.output_frame_size:
                    yield (self.output_sample_rate,
                           self.all_output_data[:self.output_frame_size].reshape(1, -1))
                    self.all_output_data = self.all_output_data[self.output_frame_size:]

    def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next output frame, or ``None`` when none is available.

        Mirrors the error policy of ``receive``: a failure while reading from
        the socket is logged, the connection is dropped and ``None`` is
        returned instead of propagating to the caller.
        """
        if not self.ws:
            return None
        if self._generator is None:
            self._generator = self.generator()
        try:
            return next(self._generator)
        except StopIteration:
            self.reset()
            return None
        except Exception as e:
            print(f"Error in emit: {str(e)}")
            self._close_ws()
            self.reset()
            return None

    def reset(self) -> None:
        """Discard the active generator and any buffered output samples."""
        self._generator = None
        self.all_output_data = None

    def shutdown(self) -> None:
        """Close the connection so a closed socket is never reused."""
        self._close_ws()
119
+
120
# Assemble the Gradio page: title banner, chat panel and a WebRTC audio widget
with gr.Blocks() as demo:
    gr.HTML("""
    <div style='text-align: center'>
        <h1>Gemini 2.0 Voice Chat</h1>
        <p>Speak with Gemini using real-time audio streaming</p>
    </div>
    """)

    # NOTE(review): the chatbot is rendered but not wired to any event here.
    chatbot = gr.Chatbot(type='messages')
    webrtc = WebRTC(label="Conversation", modality="audio", mode="send-receive")

    # The same component is both input and output: audio goes through
    # GeminiHandler and comes back to it, with time_limit=90 per stream.
    webrtc.stream(GeminiHandler(), inputs=[webrtc], outputs=[webrtc], time_limit=90)


# Start the app only when executed as a script
if __name__ == "__main__":
    demo.launch()