# Source: app.py by freddyaboulton (commit bf0cc6a, "Update app.py", verified; 4.49 kB).
import logging
import base64
import io
import os
from threading import Thread
import gradio as gr
import numpy as np
import requests
from gradio_webrtc import ReplyOnPause, WebRTC, AdditionalOutputs
from pydub import AudioSegment
from twilio.rest import Client
from server import serve
# Root logging is configured at WARNING (basicConfig is a no-op if the root
# logger already has handlers); the "gradio_webrtc" logger additionally writes
# everything from DEBUG upward to a local log file.
logging.basicConfig(level=logging.WARNING)

formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

file_handler = logging.FileHandler("gradio_webrtc.log")
file_handler.setFormatter(formatter)
file_handler.setLevel(logging.DEBUG)

logger = logging.getLogger("gradio_webrtc")
logger.addHandler(file_handler)
logger.setLevel(logging.DEBUG)
# Address/port of the local chat server that `response` talks to.
IP = "0.0.0.0"
PORT = 60808

# Start the model server in the background; daemon=True so this thread does
# not keep the process alive on interpreter exit.
# NOTE(review): serve() is not passed IP/PORT here — presumably its own
# defaults match these values; confirm against server.serve.
thread = Thread(target=serve, daemon=True)
thread.start()

# Built from IP/PORT (same value as before: "http://0.0.0.0:60808/chat") so
# the client URL cannot drift out of sync with the constants above.
# NOTE(review): 0.0.0.0 as a *destination* reaches loopback on Linux but is
# not portable (e.g. Windows); 127.0.0.1 may be safer — confirm.
API_URL = f"http://{IP}:{PORT}/chat"
# ICE/TURN servers are only needed if deploying on a cloud provider. They are
# requested from Twilio only when both credentials are present in the
# environment; otherwise the WebRTC component gets no rtc_configuration.
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
rtc_configuration = None
if account_sid and auth_token:
    client = Client(account_sid, auth_token)
    token = client.tokens.create()
    # "relay" limits ICE candidates to the relay (TURN) servers Twilio returned.
    rtc_configuration = {
        "iceServers": token.ice_servers,
        "iceTransportPolicy": "relay",
    }
# Output audio format.
# NOTE(review): OUT_CHANNELS, OUT_SAMPLE_WIDTH and OUT_CHUNK are not referenced
# anywhere in this file's visible code; they may be leftovers — confirm.
OUT_CHANNELS = 1
# Sample rate of the reply audio: handed to ReplyOnPause as
# output_sample_rate and yielded with every audio frame in `response`.
OUT_RATE = 24_000
# Presumably bytes per sample (consistent with the int16 decoding in
# `response`) — confirm.
OUT_SAMPLE_WIDTH = 2
OUT_CHUNK = 4096 * 20
def _iter_stream_frames(chunks):
    r"""Split a streamed byte sequence into frames delimited by ``b"\r\n--frame\r\n"``.

    A frame is yielded as soon as the delimiter following it has arrived.
    Whatever is still buffered when *chunks* is exhausted is yielded as a final
    frame (if non-empty): when the server writes the delimiter *before* each
    part, the last part is never followed by one and would otherwise be lost.

    Args:
        chunks: Iterable of ``bytes`` pieces, e.g. ``Response.iter_content()``.

    Yields:
        The raw bytes of each frame (part headers, blank line, payload) with the
        delimiter removed. May yield ``b""`` if the stream starts with a
        delimiter; callers skip frames without a recognised Content-Type.
    """
    delimiter = b"\r\n--frame\r\n"
    buffer = b""
    for chunk in chunks:
        buffer += chunk
        while delimiter in buffer:
            frame, buffer = buffer.split(delimiter, 1)
            yield frame
    if buffer:
        yield buffer


def response(audio: tuple[int, np.ndarray], conversation: list[dict], img: str | None):
    """Send one user utterance (plus optional image) to the chat server and stream the reply.

    Used as the ReplyOnPause handler. The captured speech is written to an
    in-memory mono WAV, base64-encoded and POSTed as JSON to ``API_URL``; the
    server's streamed reply is then parsed frame by frame.

    Args:
        audio: ``(sampling_rate, samples)`` as delivered by ReplyOnPause; the
            samples are squeezed and encoded as a single-channel WAV whose
            sample width is the array's dtype itemsize.
        conversation: Chat history in Gradio "messages" format. A user turn and
            an initially empty assistant turn are appended to it in place.
        img: Path to an image file to send along with the audio, or ``None``.

    Yields:
        ``(OUT_RATE, audio_array, "mono")`` for every audio frame, where
        ``audio_array`` is int16 with shape ``(1, n)``; and
        ``AdditionalOutputs(conversation)`` after every text frame, with the
        assistant turn set to the reply text accumulated so far.

    Raises:
        Exception: wrapping any error raised while checking the HTTP status or
            reading/parsing the streamed reply.
    """
    sampling_rate, audio_np = audio
    audio_np = audio_np.squeeze()

    # Encode the utterance as an in-memory mono WAV file for the server.
    audio_buffer = io.BytesIO()
    segment = AudioSegment(
        audio_np.tobytes(),
        frame_rate=sampling_rate,
        sample_width=audio_np.dtype.itemsize,
        channels=1,
    )
    segment.export(audio_buffer, format="wav")

    conversation.append({"role": "user", "content": gr.Audio((sampling_rate, audio_np))})
    conversation.append({"role": "assistant", "content": ""})

    if API_URL is None:
        return

    payload = {"audio": base64.b64encode(audio_buffer.getvalue()).decode("utf-8")}
    if img is not None:
        # Context manager so the image file handle is closed promptly.
        with open(img, "rb") as image_file:
            payload["image"] = base64.b64encode(image_file.read()).decode("utf-8")

    print("sending request to server")
    resp_text = ""
    # Named `resp` so it does not shadow this function's own name.
    with requests.post(API_URL, json=payload, stream=True) as resp:
        try:
            # Without this, an error status would be parsed as an empty stream
            # and the user would silently get no reply.
            resp.raise_for_status()
            for frame in _iter_stream_frames(resp.iter_content(chunk_size=2048)):
                if b"Content-Type: audio/wav" in frame:
                    # The payload follows the blank line after the part headers.
                    # NOTE(review): decoded as raw int16 PCM (no WAV header
                    # parsing) despite the audio/wav type — presumably the
                    # server sends headerless PCM; confirm.
                    audio_data = frame.split(b"\r\n\r\n", 1)[1]
                    audio_array = np.frombuffer(audio_data, dtype=np.int16).reshape(1, -1)
                    yield (OUT_RATE, audio_array, "mono")
                elif b"Content-Type: text/plain" in frame:
                    text_data = frame.split(b"\r\n\r\n", 1)[1].decode()
                    resp_text += text_data
                    conversation[-1]["content"] = resp_text
                    yield AdditionalOutputs(conversation)
        except Exception as e:
            raise Exception(f"Error during audio streaming: {e}") from e
# Page title shown above the app (kept byte-for-byte as the original literal).
_TITLE_HTML = """
<h1 style='text-align: center'>
Mini-Omni-2 Chat (Powered by WebRTC ⚡️)
</h1>
"""

with gr.Blocks() as demo:
    gr.HTML(_TITLE_HTML)

    # Layout: left column holds the WebRTC stream next to an optional image
    # input; right column holds the chat transcript.
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    audio = WebRTC(
                        label="Stream",
                        rtc_configuration=rtc_configuration,
                        mode="send-receive",
                        modality="audio",
                    )
                with gr.Column():
                    img = gr.Image(label="Image", type="filepath")
        with gr.Column():
            conversation = gr.Chatbot(label="Conversation", type="messages")

    # `response` runs each time the speaker pauses; its audio goes back out on
    # the same WebRTC stream, capped at 90 seconds per stream.
    reply_handler = ReplyOnPause(
        response,
        output_sample_rate=OUT_RATE,
        output_frame_size=480,
    )
    audio.stream(
        fn=reply_handler,
        inputs=[audio, conversation, img],
        outputs=[audio],
        time_limit=90,
    )

    def _show_conversation(updated_conversation):
        """Pass the conversation carried by AdditionalOutputs straight to the Chatbot."""
        return updated_conversation

    audio.on_additional_outputs(_show_conversation, outputs=[conversation])

demo.launch()