# Source: Hugging Face Space by Twelve2five — app.py, commit f558bc0 ("Update app.py", verified), 9.17 kB.
import os
import time
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from elevenlabs import ElevenLabs
from fastrtc import (
Stream,
get_stt_model,
ReplyOnPause,
AdditionalOutputs
)
from gradio.utils import get_space
import requests
import io
import soundfile as sf
from gtts import gTTS
import re
import logging
# Verbose root logging (DEBUG) so connection/negotiation problems are visible.
logging.basicConfig(level="DEBUG")
logger = logging.getLogger("fastrtc-voice-assistant")

# Read variables from a .env file (if present) before any os.getenv/os.environ
# lookups below, so the API keys are available.
load_dotenv()

# NOTE(review): presumably consumed by the WebRTC stack to enable full tracing —
# confirm something actually reads WEBRTC_TRACE.
os.environ.update(WEBRTC_TRACE="WEBRTC_TRACE_ALL")

logger.info("Initializing clients...")
# Text-to-speech client; the key may be absent, in which case the reply
# handler falls back to gTTS.
elevenlabs_client = ElevenLabs(api_key=os.environ.get("ELEVENLABS_API_KEY"))
# Speech-to-text model used to transcribe each user utterance.
stt_model = get_stt_model()
logger.info("Clients initialized")
class DeepSeekAPI:
    """Minimal client for the DeepSeek chat-completions HTTP endpoint.

    Any failure (network error, timeout, non-200 status, non-JSON body) is
    logged and replaced by a canned apology shaped like a normal response, so
    callers can always read ``["choices"][0]["message"]["content"]``.
    """

    URL = "https://api.deepseek.com/v1/chat/completions"
    FALLBACK_TEXT = "I'm sorry, I encountered an error processing your request."

    def __init__(self, api_key, timeout=60.0):
        """Store credentials and request settings.

        Args:
            api_key: DeepSeek API key, sent as a Bearer token.
            timeout: Seconds passed to ``requests.post`` as its ``timeout``.
                Without a timeout a stalled connection would block the
                calling voice handler indefinitely.
        """
        self.api_key = api_key
        self.timeout = timeout

    @classmethod
    def _fallback_response(cls):
        """Return the apology payload in the same shape as a successful response."""
        return {"choices": [{"message": {"content": cls.FALLBACK_TEXT}}]}

    def chat_completion(self, messages, temperature=0.7, max_tokens=512):
        """Request a chat completion from the ``deepseek-chat`` model.

        Args:
            messages: List of ``{"role": ..., "content": ...}`` dicts.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            The decoded JSON response, or ``_fallback_response()`` on any error.
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        payload = {
            "model": "deepseek-chat",
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens
        }
        try:
            response = requests.post(
                self.URL, json=payload, headers=headers, timeout=self.timeout
            )
        except requests.RequestException as e:
            # Connection failures and timeouts would otherwise crash the handler.
            logger.error(f"DeepSeek API request failed: {e}")
            return self._fallback_response()

        if response.status_code != 200:
            logger.error(f"DeepSeek API error: {response.status_code} - {response.text}")
            return self._fallback_response()

        try:
            return response.json()
        except ValueError as e:
            # A 200 with a non-JSON body (e.g. an HTML proxy page).
            logger.error(f"DeepSeek API returned invalid JSON: {e}")
            return self._fallback_response()
# Module-wide DeepSeek client; the API key is taken from the environment.
deepseek_client = DeepSeekAPI(os.environ.get("DEEPSEEK_API_KEY"))
def _iter_pcm16_frames(byte_chunks, sample_rate=24000):
    """Yield ``(sample_rate, frame)`` tuples from a stream of raw int16 PCM bytes.

    Streamed HTTP chunks need not end on a 2-byte sample boundary, and
    ``np.frombuffer(..., dtype=np.int16)`` raises ``ValueError`` for an
    odd-length buffer. A dangling byte is therefore carried into the next
    chunk; a final unmatched byte is dropped. Each frame has shape ``(1, n)``.
    """
    leftover = b""
    for chunk in byte_chunks:
        buf = leftover + chunk
        usable = len(buf) - len(buf) % 2
        leftover = buf[usable:]
        if usable:
            yield (sample_rate, np.frombuffer(buf[:usable], dtype=np.int16).reshape(1, -1))
    if leftover:
        logger.warning(f"Dropping {len(leftover)} trailing byte(s) of an incomplete PCM sample")


def response(
    audio: tuple[int, np.ndarray],
    chatbot: list[dict] | None = None,
):
    """Handle one spoken user turn: transcribe it, ask DeepSeek, speak the reply.

    This generator is the handler wrapped by ``ReplyOnPause`` for the Stream.

    Args:
        audio: ``(sample_rate, samples)`` for the utterance, as passed by fastrtc.
        chatbot: Chat history as ``{"role", "content"}`` dicts (Gradio
            "messages" format); ``None`` is treated as empty. The list is
            appended to in place.

    Yields:
        ``AdditionalOutputs(chatbot)`` once the user turn is added; then audio
        for the reply (24 kHz int16 frames of shape ``(1, n)`` from ElevenLabs,
        or whatever ``use_gtts_for_text`` yields on fallback); then
        ``AdditionalOutputs(chatbot)`` including the assistant turn. Nothing is
        yielded when the transcript is empty.
    """
    chatbot = chatbot or []
    messages = [{"role": d["role"], "content": d["content"]} for d in chatbot]

    # Convert speech to text
    logger.info("Converting speech to text...")
    text = stt_model.stt(audio)
    logger.info(f"User said: {text}")
    if not text or not text.strip():
        # Pause detection can fire on background noise; don't send an empty
        # user turn to the LLM or add a blank bubble to the chat.
        logger.info("No speech recognized, skipping reply")
        return

    # Show the user's turn immediately
    chatbot.append({"role": "user", "content": text})
    yield AdditionalOutputs(chatbot)

    # Get AI response
    messages.append({"role": "user", "content": text})
    logger.info("Calling DeepSeek API...")
    response_data = deepseek_client.chat_completion(messages)
    response_text = response_data["choices"][0]["message"]["content"]
    logger.info(f"DeepSeek response: {response_text[:50]}...")
    chatbot.append({"role": "assistant", "content": response_text})

    # Convert response to speech
    if os.getenv("ELEVENLABS_API_KEY"):
        try:
            logger.info("Using ElevenLabs for speech generation")
            pcm_stream = elevenlabs_client.text_to_speech.convert_as_stream(
                text=response_text,
                # The endpoint takes a voice ID, not a display name; this is
                # the premade "Antoni" voice.
                # NOTE(review): verify the ID in the ElevenLabs voice library.
                voice_id="ErXwobaYiR6WXGBzDpk0",
                model_id="eleven_monolingual_v1",
                output_format="pcm_24000",
            )
            yield from _iter_pcm16_frames(pcm_stream, 24000)
        except Exception as e:
            # If streaming fails part-way, gTTS re-speaks the whole reply.
            logger.error(f"ElevenLabs error: {e}, falling back to gTTS")
            yield from use_gtts_for_text(response_text)
    else:
        logger.info("ElevenLabs API key not found, using gTTS...")
        yield from use_gtts_for_text(response_text)

    yield AdditionalOutputs(chatbot)
def _audio_to_int16_mono(data, samplerate, target_rate=24000):
    """Convert soundfile float samples (in [-1, 1]) to a 1-D int16 array at *target_rate*.

    Keeps only the first channel of multi-channel input (including ``(n, 1)``
    arrays, which would otherwise break ``np.interp``) and resamples by linear
    interpolation.
    """
    if data.ndim > 1:
        data = data[:, 0]
    if samplerate != target_rate and len(data) > 0:
        new_len = int(len(data) * target_rate / samplerate)
        # Positions must span the valid index range [0, len - 1]; a range
        # ending at len(data) overshoots by one sample and clamps the tail.
        positions = np.linspace(0, len(data) - 1, new_len)
        data = np.interp(positions, np.arange(len(data)), data)
    return (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)


def use_gtts_for_text(text):
    """Generate speech for *text* with gTTS, yielding 24 kHz int16 audio frames.

    The text is split after sentence-ending punctuation (``.``, ``!``, ``?``)
    and each sentence is synthesized separately, so playback can start before
    the whole reply has been converted. A sentence that fails to synthesize or
    decode is logged and skipped; the remaining sentences are still spoken.

    Yields:
        ``(24000, frame)`` tuples, where ``frame`` is an int16 array of shape
        ``(1, n)`` with ``n <= 4800`` (0.2 s at 24 kHz).
    """
    sample_rate = 24000
    chunk_size = 4800
    for sentence in re.split(r'(?<=[.!?])\s+', text):
        if not sentence.strip():
            continue
        logger.info(f"Using gTTS for: {sentence[:30]}...")
        try:
            mp3_fp = io.BytesIO()
            # gTTS selects the accent via `tld`; dialect codes like 'en-us'
            # are deprecated in favour of lang='en' + tld.
            gTTS(text=sentence, lang='en', tld='com', slow=False).write_to_fp(mp3_fp)
            mp3_fp.seek(0)
            data, samplerate = sf.read(mp3_fp)
            pcm = _audio_to_int16_mono(data, samplerate, sample_rate)
        except Exception as e:
            logger.error(f"gTTS error: {e}")
            continue
        for start in range(0, len(pcm), chunk_size):
            yield (sample_rate, pcm[start:start + chunk_size].reshape(1, -1))
# WebRTC ICE configuration handed to the Stream: public STUN servers for
# address discovery plus OpenRelay public TURN relays for cases where no
# direct path can be established.
rtc_configuration = dict(
    iceServers=[
        # Google STUN servers
        *(
            {"urls": [url]}
            for url in (
                "stun:stun.l.google.com:19302",
                "stun:stun1.l.google.com:19302",
                "stun:stun2.l.google.com:19302",
                "stun:stun3.l.google.com:19302",
                "stun:stun4.l.google.com:19302",
            )
        ),
        # OpenRelay TURN servers (shared public credentials)
        *(
            {
                "urls": [url],
                "username": "openrelayproject",
                "credential": "openrelayproject",
            }
            for url in (
                "turn:openrelay.metered.ca:80",
                "turn:openrelay.metered.ca:443",
                "turn:openrelay.metered.ca:443?transport=tcp",
            )
        ),
        # Additional public STUN servers
        *(
            {"urls": [url]}
            for url in (
                "stun:stun.stunprotocol.org:3478",
                "stun:stun.voip.blackberry.com:3478",
                "stun:stun.nextcloud.com:443",
            )
        ),
    ],
    iceCandidatePoolSize=10,
    bundlePolicy="max-bundle",
    rtcpMuxPolicy="require",
    # "all" permits direct and relayed candidates; "relay" would force TURN
    # if direct connections keep failing.
    iceTransportPolicy="all",
)
# Gradio UI: chat transcript plus read-only connection status/debug fields.
with gr.Blocks(title="LLM Voice Chat (Powered by DeepSeek & ElevenLabs)") as demo:
    gr.Markdown("# LLM Voice Chat\nPowered by DeepSeek & ElevenLabs")
    with gr.Row():
        with gr.Column(scale=3):
            # Conversation history in "messages" format ({"role", "content"} dicts).
            chatbot = gr.Chatbot(type="messages")
            # For debugging, allow seeing connection status
            connection_status = gr.Textbox(
                label="Connection Status",
                value="Ready to connect. Click the microphone button to start.",
                interactive=False,
            )
            # Display debugging information
            debug_info = gr.Textbox(
                label="Debug Info",
                value="WebRTC debug information will appear here.",
                interactive=False,
            )
            refresh_btn = gr.Button("Refresh Connection")

    def refresh_page():
        """Report a refresh attempt in the status and debug text boxes.

        Returns:
            New values for ``connection_status`` and ``debug_info``.

        The new debug text is delivered through the return value only.
        Assigning ``debug_info.value`` inside a handler does not update the
        browser, and it mutates the component's server-side default.

        NOTE(review): this only updates the two text boxes; it does not
        reload the page or renegotiate the WebRTC connection.
        """
        return "Refreshed", f"Connection refresh attempted at {time.time()}"

    refresh_btn.click(
        refresh_page,
        outputs=[connection_status, debug_info],
    )
# Build the fastrtc audio Stream that drives the voice conversation.
# NOTE(review): this file's original indentation was lost, so it is unclear
# whether this section ran inside the `with gr.Blocks(...) as demo:` context.
# The comment below says "outside", yet `stream.render()` is meant to mount the
# stream into `demo`, which presumably requires an active Blocks context.
# Confirm where this belongs, and that `Stream` actually exposes `render()`.
logger.info("Creating Stream component...")
# Initialize the stream (outside of the blocks context)
stream = Stream(
    modality="audio",
    # Audio goes both ways: user speech in, the handler's audio frames out.
    mode="send-receive",
    # `response` is called when the user pauses. Presumably input audio is
    # delivered at 16 kHz — TODO confirm fastrtc's input_sample_rate semantics.
    handler=ReplyOnPause(response, input_sample_rate=16000),
    # Takes two values and returns the second; assumed to be (previous, new),
    # so the chatbot shows the latest list from AdditionalOutputs — confirm.
    additional_outputs_handler=lambda a, b: b,
    # The chat history is fed into `response` and updated from what it emits.
    additional_inputs=[chatbot],
    additional_outputs=[chatbot],
    rtc_configuration=rtc_configuration,
    # Limits are only set when get_space() is truthy; otherwise None.
    concurrency_limit=5 if get_space() else None,
    time_limit=90 if get_space() else None
)
# Mount the stream to the blocks interface
stream.render()
logger.info("Stream component created and rendered")
# Launch the app.
# Gradio Spaces execute app.py as a script, so `__name__ == "__main__"` is true
# there too. Detect the Space explicitly with get_space() so Spaces never get
# the local-development settings (share=True is not supported on Spaces).
if get_space():
    logger.info("Running in Hugging Face Spaces")
    demo.launch()
elif __name__ == "__main__":
    # Local development: disable server-side rendering and open a share link.
    logger.info("Running in development mode")
    os.environ["GRADIO_SSR_MODE"] = "false"
    demo.launch(server_port=7860, share=True)
else:
    # Imported by another module outside a Space: keep the original behaviour
    # of launching with default settings.
    logger.info("Running as an imported module")
    demo.launch()