File size: 6,625 Bytes
4df6700
 
 
 
 
 
c4620f8
4df6700
 
 
 
 
 
c4620f8
4df6700
 
 
 
 
 
 
 
 
 
 
 
 
 
013f6a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4df6700
013f6a1
4df6700
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
013f6a1
 
 
 
4df6700
 
 
 
 
c4620f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4df6700
 
 
c4620f8
 
40785f3
c4620f8
40785f3
 
c4620f8
 
 
40785f3
c4620f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40785f3
c4620f8
40785f3
 
38edbec
d518218
 
42a325e
d518218
 
 
 
 
 
 
 
 
42a325e
38edbec
42a325e
d518218
 
42a325e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bd4006
42a325e
 
 
c4620f8
42a325e
38edbec
42a325e
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import os
import time
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from elevenlabs import ElevenLabs
from fastapi import FastAPI
from fastrtc import (
    Stream,
    get_stt_model,
    ReplyOnPause,
    AdditionalOutputs
)
from gradio.utils import get_space

import requests
import io
import soundfile as sf
from gtts import gTTS
import re

# Read settings (API keys etc.) from a .env file into the process environment
load_dotenv()

# Module-level service clients, created once at import time
elevenlabs_client = ElevenLabs(api_key=os.environ.get("ELEVENLABS_API_KEY"))
stt_model = get_stt_model()

class DeepSeekAPI:
    """Minimal HTTP client for the DeepSeek chat-completions endpoint.

    Any failure (network error, timeout, non-200 status, non-JSON body) is
    logged and converted into a canned reply that has the same shape as a
    successful response, so callers can always read
    ``result["choices"][0]["message"]["content"]``.
    """

    # Text of the canned reply returned whenever the request fails
    FALLBACK_MESSAGE = "I'm sorry, I encountered an error processing your request."

    def __init__(
        self,
        api_key,
        base_url="https://api.deepseek.com/v1",
        model="deepseek-chat",
        timeout=30,
    ):
        """Store the connection settings.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            base_url: Root of the API; ``/chat/completions`` is appended.
            model: Model name sent in the request payload.
            timeout: Seconds to wait for the HTTP request before giving up.
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.timeout = timeout

    def _fallback(self):
        """Return the canned error reply in the API's response shape."""
        return {"choices": [{"message": {"content": self.FALLBACK_MESSAGE}}]}

    def chat_completion(self, messages, temperature=0.7, max_tokens=512):
        """POST a chat-completion request and return the decoded JSON body.

        Args:
            messages: List of ``{"role": ..., "content": ...}`` dicts.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Completion length limit forwarded to the API.

        Returns:
            The parsed JSON response on success, otherwise the canned
            fallback reply (see ``_fallback``). Never raises for transport
            or decoding errors.
        """
        url = f"{self.base_url}/chat/completions"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        try:
            # Without a timeout a stalled connection would block the voice
            # handler indefinitely.
            response = requests.post(
                url, json=payload, headers=headers, timeout=self.timeout
            )
        except requests.RequestException as exc:
            print(f"DeepSeek API error: {exc}")
            return self._fallback()

        # Check for error response
        if response.status_code != 200:
            print(f"DeepSeek API error: {response.status_code} - {response.text}")
            return self._fallback()

        try:
            return response.json()
        except ValueError as exc:
            # A 200 with a non-JSON body (e.g. a proxy error page)
            print(f"DeepSeek API error: invalid JSON response - {exc}")
            return self._fallback()

# Shared DeepSeek client; the key is read from the DEEPSEEK_API_KEY variable
deepseek_client = DeepSeekAPI(os.environ.get("DEEPSEEK_API_KEY"))

def response(
    audio: tuple[int, np.ndarray],
    chatbot: list[dict] | None = None,
):
    """Handle one spoken user turn: transcribe, ask the LLM, speak the reply.

    This is a generator. It yields, in order:
      * ``AdditionalOutputs(chatbot)`` once the user's transcribed message
        has been appended to the history,
      * ``(24000, int16 array of shape (1, n))`` audio frames for the reply,
      * ``AdditionalOutputs(chatbot)`` again, now including the reply.

    If the transcription is empty the generator ends without yielding.

    Args:
        audio: ``(sample_rate, samples)`` tuple passed to ``stt_model.stt``.
        chatbot: Existing chat history as ``{"role", "content"}`` dicts, or
            ``None`` for a new conversation. The list is appended to in place.
    """
    chatbot = chatbot or []
    messages = [{"role": d["role"], "content": d["content"]} for d in chatbot]

    # Convert speech to text
    text = stt_model.stt(audio)
    print("prompt:", text)

    # Nothing intelligible was said: skip the LLM/TTS round trip entirely
    if not text or not text.strip():
        return

    # Add user message to chat
    chatbot.append({"role": "user", "content": text})
    yield AdditionalOutputs(chatbot)

    # Get AI response
    messages.append({"role": "user", "content": text})
    response_data = deepseek_client.chat_completion(messages)
    try:
        response_text = response_data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as e:
        # The API answered, but not in the expected shape
        print(f"Unexpected DeepSeek response format: {e}")
        response_text = "I'm sorry, I encountered an error processing your request."

    # Add AI response to chat
    chatbot.append({"role": "assistant", "content": response_text})

    # Convert response to speech
    use_fallback = True
    if os.getenv("ELEVENLABS_API_KEY"):
        try:
            print("Generating ElevenLabs speech for response")

            # Streamed chunks may split a 2-byte sample across chunk
            # boundaries, so carry any odd trailing byte into the next chunk
            # instead of letting np.frombuffer raise on it.
            leftover = b""
            for chunk in elevenlabs_client.text_to_speech.convert_as_stream(
                text=response_text,
                voice_id="Antoni",
                model_id="eleven_monolingual_v1",
                output_format="pcm_24000"
            ):
                pcm = leftover + chunk
                usable = len(pcm) - (len(pcm) % 2)
                pcm, leftover = pcm[:usable], pcm[usable:]
                if not pcm:
                    continue
                audio_array = np.frombuffer(pcm, dtype=np.int16).reshape(1, -1)
                yield (24000, audio_array)
            use_fallback = False
        except Exception as e:
            # NOTE(review): if this fires after some audio was already
            # yielded, the gTTS fallback speaks the whole reply again.
            print(f"ElevenLabs error: {e}, falling back to gTTS")
    else:
        print("ElevenLabs API key not found, using gTTS...")

    if use_fallback:
        # Fall back to gTTS; drop any non-audio (None) item the helper emits
        for frame in use_gtts_for_text(response_text):
            if frame is not None:
                yield frame

    yield AdditionalOutputs(chatbot)

def use_gtts_for_text(text):
    """Synthesize *text* with gTTS and yield it as 24 kHz int16 audio frames.

    The text is split into sentences at ``.``, ``!`` or ``?`` followed by
    whitespace; each sentence is synthesized separately, decoded with
    ``soundfile``, reduced to its first channel, resampled to 24 kHz by
    linear interpolation and emitted in frames of at most 4800 samples.

    Yields:
        ``(24000, samples)`` tuples where ``samples`` is an ``int16`` array
        of shape ``(1, n)``.

    Any exception is logged and ends the generator; frames already yielded
    are unaffected and no placeholder item is emitted.
    """
    target_rate = 24000
    chunk_size = 4800  # samples per yielded frame

    try:
        # Split text into sentences for better results
        for sentence in re.split(r'(?<=[.!?])\s+', text):
            if not sentence.strip():
                continue

            mp3_fp = io.BytesIO()
            print(f"Using gTTS for sentence: {sentence[:30]}...")
            # 'en' + tld 'com' is the current spelling of the deprecated
            # 'en-us' language code.
            tts = gTTS(text=sentence, lang='en', tld='com', slow=False)
            tts.write_to_fp(mp3_fp)
            mp3_fp.seek(0)

            data, samplerate = sf.read(mp3_fp)

            # Keep only the first channel of multi-channel audio
            if data.ndim > 1:
                data = data[:, 0]
            if len(data) == 0:
                # np.interp rejects empty sample arrays; nothing to play
                continue

            if samplerate != target_rate:
                new_len = int(len(data) * target_rate / samplerate)
                # endpoint=False keeps every query position inside
                # [0, len(data)) rather than sampling one past the last index
                data = np.interp(
                    np.linspace(0, len(data), new_len, endpoint=False),
                    np.arange(len(data)),
                    data,
                )

            # Clip before scaling so values outside [-1, 1] saturate instead
            # of wrapping around in the int16 cast
            data = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)

            # Ensure an even sample count; pad with an int16 zero so the
            # buffer is not promoted to int64
            if len(data) % 2 != 0:
                data = np.append(data, np.zeros(1, dtype=np.int16))

            # Reshape and yield in chunks
            for start in range(0, len(data), chunk_size):
                yield (target_rate, data[start:start + chunk_size].reshape(1, -1))
    except Exception as e:
        print(f"gTTS error: {e}")

# WebRTC ICE configuration: two Google STUN servers plus the OpenRelay TURN
# relay, reachable on port 80 and on port 443 over TCP.
_STUN_URLS = ["stun:stun.l.google.com:19302", "stun:stun1.l.google.com:19302"]
_TURN_URLS = (
    "turn:openrelay.metered.ca:80",
    "turn:openrelay.metered.ca:443?transport=tcp",
)
# Both TURN entries share the same credentials
_TURN_AUTH = {"username": "openrelayproject", "credential": "openrelayproject"}

rtc_configuration = {
    "iceServers": [
        {"urls": _STUN_URLS},
        *({"urls": [turn_url], **_TURN_AUTH} for turn_url in _TURN_URLS),
    ],
    "iceCandidatePoolSize": 10,
}

# Chat transcript component and the audio Stream wired to it, following the
# cookbook pattern.
def _latest_history(_previous, updated):
    """Show the newly emitted chat history in place of the previous one."""
    return updated


# Concurrency and time limits are only applied when get_space() is truthy
_on_space = bool(get_space())

chatbot = gr.Chatbot(type="messages")
stream = Stream(
    handler=ReplyOnPause(response, input_sample_rate=16000),
    modality="audio",
    mode="send-receive",
    additional_inputs=[chatbot],
    additional_outputs=[chatbot],
    additional_outputs_handler=_latest_history,
    rtc_configuration=rtc_configuration,
    concurrency_limit=5 if _on_space else None,
    time_limit=90 if _on_space else None,
    ui_args={"title": "LLM Voice Chat (Powered by DeepSeek & ElevenLabs)"},
)

# FastAPI application serving the Stream's Gradio UI at the root path
app = gr.mount_gradio_app(FastAPI(), stream.ui, path="/")

# Script entry point (not executed when the module is merely imported)
if __name__ == "__main__":
    import uvicorn

    os.environ["GRADIO_SSR_MODE"] = "false"

    if not get_space():
        # Local development: launch the Gradio UI directly
        stream.ui.launch(server_port=7860)
    else:
        # On HF Spaces: serve the FastAPI app on the port given by PORT
        # (7860 when unset), listening on all interfaces
        uvicorn.run(
            app,
            host="0.0.0.0",
            port=int(os.environ.get("PORT", 7860)),
        )