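"""LLM voice chat app: FastRTC streams microphone audio to the server, a
DeepSeek model writes the reply, and ElevenLabs (with a gTTS fallback)
speaks it back.
"""
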
import os
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from elevenlabs import ElevenLabs
from fastapi import FastAPI
from fastrtc import (
    Stream,
    get_stt_model,
    ReplyOnPause,
    AdditionalOutputs
)
from gradio.utils import get_space
import requests
import io
import soundfile as sf
from gtts import gTTS
import re

# Load environment variables
load_dotenv()

# Initialize clients
elevenlabs_client = ElevenLabs(api_key=os.getenv("ELEVENLABS_API_KEY"))
stt_model = get_stt_model()  # fastrtc's bundled local speech-to-text model


class DeepSeekAPI:
    """Minimal client for the DeepSeek chat-completions REST endpoint."""

    def __init__(self, api_key):
        self.api_key = api_key

    def chat_completion(self, messages, temperature=0.7, max_tokens=512):
        url = "https://api.deepseek.com/v1/chat/completions"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        payload = {
            "model": "deepseek-chat",
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens
        }
        response = requests.post(url, json=payload, headers=headers)
        # Check for an error response
        if response.status_code != 200:
            print(f"DeepSeek API error: {response.status_code} - {response.text}")
            # Return a stub in the same shape so the caller's parsing still works
            return {"choices": [{"message": {"content": "I'm sorry, I encountered an error processing your request."}}]}
        return response.json()

deepseek_client = DeepSeekAPI(api_key=os.getenv("DEEPSEEK_API_KEY"))
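

# ReplyOnPause handler: FastRTC calls this with the buffered microphone audio
# each time the speaker pauses. It transcribes the audio, queries DeepSeek,
# streams the spoken reply back as 24 kHz PCM chunks, and pushes chat updates
# via AdditionalOutputs.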
def response(
    audio: tuple[int, np.ndarray],
    chatbot: list[dict] | None = None,
):
    chatbot = chatbot or []
    messages = [{"role": d["role"], "content": d["content"]} for d in chatbot]

    # Convert speech to text
    text = stt_model.stt(audio)
    print("prompt:", text)

    # Add user message to chat
    chatbot.append({"role": "user", "content": text})
    yield AdditionalOutputs(chatbot)

    # Get AI response
    messages.append({"role": "user", "content": text})

    # Call DeepSeek API
    response_data = deepseek_client.chat_completion(messages)
    response_text = response_data["choices"][0]["message"]["content"]

    # Add AI response to chat
    chatbot.append({"role": "assistant", "content": response_text})

    # Convert response to speech
    if os.getenv("ELEVENLABS_API_KEY"):
        try:
            print("Generating ElevenLabs speech for response")
            # Stream raw 16-bit PCM at 24 kHz for lower latency.
            # Note: convert_as_stream expects a voice ID; "Antoni" is a premade
            # voice name, so swap in its ID if the request fails.
            for chunk in elevenlabs_client.text_to_speech.convert_as_stream(
                text=response_text,
                voice_id="Antoni",
                model_id="eleven_monolingual_v1",
                output_format="pcm_24000"
            ):
                audio_array = np.frombuffer(chunk, dtype=np.int16).reshape(1, -1)
                yield (24000, audio_array)
        except Exception as e:
            print(f"ElevenLabs error: {e}, falling back to gTTS")
            # Fall back to gTTS
            yield from use_gtts_for_text(response_text)
    else:
        # Fall back to gTTS
        print("ElevenLabs API key not found, using gTTS...")
        yield from use_gtts_for_text(response_text)

    yield AdditionalOutputs(chatbot)


def use_gtts_for_text(text):
    """Helper function to generate speech with gTTS for the entire text."""
    try:
        # Split text into sentences for better results
        sentences = re.split(r'(?<=[.!?])\s+', text)
        for sentence in sentences:
            if not sentence.strip():
                continue
            mp3_fp = io.BytesIO()
            print(f"Using gTTS for sentence: {sentence[:30]}...")
            # 'en' with tld='com' selects the US voice; recent gTTS releases
            # reject region codes such as 'en-us'
            tts = gTTS(text=sentence, lang='en', tld='com', slow=False)
            tts.write_to_fp(mp3_fp)
            mp3_fp.seek(0)
            # Decoding MP3 needs libsndfile with MP3 support (soundfile >= 0.12)
            data, samplerate = sf.read(mp3_fp)
            # Downmix stereo to mono
            if len(data.shape) > 1 and data.shape[1] > 1:
                data = data[:, 0]
            # Linear resample to the 24 kHz rate the handler yields
            if samplerate != 24000:
                data = np.interp(
                    np.linspace(0, len(data), int(len(data) * 24000 / samplerate)),
                    np.arange(len(data)),
                    data
                )
            data = (data * 32767).astype(np.int16)
            # Ensure buffer size is even
            if len(data) % 2 != 0:
                data = np.append(data, [0])
            # Reshape and yield in chunks
            chunk_size = 4800
            for i in range(0, len(data), chunk_size):
                chunk = data[i:i + chunk_size]
                if len(chunk) > 0:
                    if len(chunk) % 2 != 0:
                        chunk = np.append(chunk, [0])
                    chunk = chunk.reshape(1, -1)
                    yield (24000, chunk)
    except Exception as e:
        print(f"gTTS error: {e}")
        yield None


# Enhanced WebRTC configuration with extra STUN/TURN servers so clients behind
# restrictive NATs can still connect
rtc_configuration = {
    "iceServers": [
        {"urls": ["stun:stun.l.google.com:19302", "stun:stun1.l.google.com:19302"]},
        {
            "urls": ["turn:openrelay.metered.ca:80"],
            "username": "openrelayproject",
            "credential": "openrelayproject"
        },
        {
            "urls": ["turn:openrelay.metered.ca:443?transport=tcp"],
            "username": "openrelayproject",
            "credential": "openrelayproject"
        }
    ],
    "iceCandidatePoolSize": 10
}


# Create the Gradio chatbot and stream, following the FastRTC cookbook pattern
chatbot = gr.Chatbot(type="messages")
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=ReplyOnPause(response, input_sample_rate=16000),
    additional_outputs_handler=lambda a, b: b,
    additional_inputs=[chatbot],
    additional_outputs=[chatbot],
    rtc_configuration=rtc_configuration,
    concurrency_limit=5 if get_space() else None,
    time_limit=90 if get_space() else None,
    ui_args={"title": "LLM Voice Chat (Powered by DeepSeek & ElevenLabs)"}
)


# Mount the Stream UI onto the FastAPI app
app = FastAPI()
app = gr.mount_gradio_app(app, stream.ui, path="/")

# Only for local development
if __name__ == "__main__":
    import uvicorn

    os.environ["GRADIO_SSR_MODE"] = "false"
    if get_space():
        # When running in HF Spaces, use their port and host
        port = int(os.environ.get("PORT", 7860))
        uvicorn.run(app, host="0.0.0.0", port=port)
    else:
        # For local development
        stream.ui.launch(server_port=7860)