import gradio as gr
import numpy as np
import io
import tempfile
from pydub import AudioSegment
from dataclasses import dataclass, field
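

# Per-session conversation state. Gradio keeps a separate copy of the
# gr.State default for each session, so every visitor gets their own AppState.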
|
@dataclass
class AppState:
    stream: np.ndarray | None = None
    sampling_rate: int = 0
    pause_detected: bool = False
    stopped: bool = False
    started_talking: bool = False
    conversation: list = field(default_factory=list)
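

# Runs every `stream_every` seconds while the microphone is recording:
# append the newest chunk to the buffer and stop recording once a pause
# (here: a crude length threshold) is detected.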
|
def process_audio(audio: tuple, state: AppState):
    sampling_rate, chunk = audio
    if state.stream is None:
        state.stream = chunk
        state.sampling_rate = sampling_rate
    else:
        state.stream = np.concatenate((state.stream, chunk))

    # Placeholder heuristic: treat more than one second of accumulated audio
    # as "the user paused". A real app would run a VAD model here instead.
    state.pause_detected = len(state.stream) > state.sampling_rate

    if state.pause_detected:
        # Returning gr.Audio(recording=False) stops the microphone, which in
        # turn fires the stop_recording event wired up below.
        return gr.Audio(recording=False), state
    return None, state
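

# Reply handlers. The text path is a plain function; the audio path is a
# generator, because a streaming output component expects its handler to
# yield. The original single `response(user_input, state, input_type)`
# mixed `return` and `yield`, which silently turned the text branch into a
# generator, so it is split in two here.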
|
def response_text(user_input: str, state: AppState):
    state.conversation.append({"role": "user", "content": user_input})
    bot_response = f"Echo: {user_input}"
    state.conversation.append({"role": "assistant", "content": bot_response})
    return state.conversation, state


def response_audio(state: AppState):
    if not state.pause_detected:
        yield None, state
        return

    # Re-encode the accumulated numpy buffer as a WAV file for the chat log.
    audio_buffer = io.BytesIO()
    segment = AudioSegment(
        state.stream.tobytes(),
        frame_rate=state.sampling_rate,
        sample_width=state.stream.dtype.itemsize,
        channels=1 if state.stream.ndim == 1 else state.stream.shape[1],
    )
    segment.export(audio_buffer, format="wav")

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        f.write(audio_buffer.getvalue())
    state.conversation.append(
        {"role": "user", "content": {"path": f.name, "mime_type": "audio/wav"}}
    )

    # Stand-in for a real speech-to-speech backend.
    chatbot_response = b"Simulated response audio content"

    with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
        f.write(chatbot_response)
    state.conversation.append(
        {"role": "assistant", "content": {"path": f.name, "mime_type": "audio/mp3"}}
    )

    # Clear the buffer so the next utterance starts fresh. A real backend
    # would yield successive audio chunks here instead of None.
    state.stream = None
    state.pause_detected = False
    yield None, state
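

# Once the bot's reply finishes playing, re-arm the microphone unless the
# user pressed "Stop Conversation".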
|
def start_recording_user(state: AppState):
    if not state.stopped:
        return gr.Audio(recording=True)
    return gr.Audio(recording=False)
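

# UI: microphone and textbox on the left; chat history and the streamed
# audio reply on the right.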
|
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(
                label="Input Audio", sources=["microphone"], type="numpy"
            )
            text_input = gr.Textbox(
                label="Text Input", placeholder="Type your message here..."
            )
        with gr.Column():
            chatbot = gr.Chatbot(label="Conversation", type="messages")
            output_audio = gr.Audio(label="Output Audio", streaming=True, autoplay=True)

    state = gr.State(value=AppState())

    # Send microphone chunks to process_audio every 0.5 s, for at most 30 s.
    stream = input_audio.stream(
        process_audio,
        [input_audio, state],
        [input_audio, state],
        stream_every=0.5,
        time_limit=30,
    )

    text_submit = text_input.submit(
        response_text, [text_input, state], [chatbot, state]
    )

    # Recording stops either manually or when process_audio detects a pause;
    # either way, generate a reply, then refresh the chat history.
    respond = input_audio.stop_recording(
        response_audio, [state], [output_audio, state]
    )
    respond.then(lambda s: s.conversation, [state], [chatbot])

    # Restart recording once the reply audio stops playing.
    restart = output_audio.stop(start_recording_user, [state], [input_audio])

    cancel = gr.Button("Stop Conversation", variant="stop")
    cancel.click(
        lambda: (AppState(stopped=True), gr.Audio(recording=False)),
        None,
        [state, input_audio],
        cancels=[respond, restart],
    )

if __name__ == "__main__":
    demo.launch()