# Gradio demo Space: streaming voice/text chat with simulated assistant replies.
import io
import tempfile
from dataclasses import dataclass, field

import gradio as gr
import numpy as np
from pydub import AudioSegment
class AppState: | |
stream: np.ndarray | None = None | |
sampling_rate: int = 0 | |
pause_detected: bool = False | |
stopped: bool = False | |
started_talking: bool = False | |
conversation: list = field(default_factory=list) | |
# Process audio input and decide when to stop recording
def process_audio(audio: tuple, state: "AppState", pause_threshold_s: float = 1.0):
    """Accumulate one streamed audio chunk and decide whether to stop recording.

    Args:
        audio: Gradio streaming payload ``(sampling_rate, samples)``.
        state: Shared AppState holding the accumulated sample buffer.
        pause_threshold_s: Seconds of buffered audio after which recording
            stops (generalizes the previously hard-coded 1 second).

    Returns:
        ``(component_update, state)`` — ``gr.Audio(recording=False)`` once
        enough audio is buffered, otherwise ``None`` (no component update).
    """
    sampling_rate, samples = audio
    if state.stream is None:
        # First chunk: remember the sample rate and start the buffer.
        state.stream = samples
        state.sampling_rate = sampling_rate
    else:
        state.stream = np.concatenate((state.stream, samples))

    # NOTE(review): this is a duration check, not true silence detection —
    # it fires as soon as more than `pause_threshold_s` seconds are buffered.
    state.pause_detected = len(state.stream) > state.sampling_rate * pause_threshold_s
    if state.pause_detected:
        return gr.Audio(recording=False), state  # stop recording
    return None, state
# Generate response based on input type (text or audio)
def response(input_data, state: "AppState", input_type: str):
    """Produce a (simulated) assistant reply for text or audio input.

    Bug fixed: the original mixed ``yield`` with value-carrying ``return``
    statements, which silently turned the whole function into a generator —
    the text path then handed Gradio an unconsumed generator object instead
    of a reply. Every path is now a plain ``return``.

    Args:
        input_data: The user's text when ``input_type == "text"``; ignored
            (None) on the audio path, which reads ``state.stream`` instead.
        state: Shared AppState; ``state.conversation`` is appended to.
        input_type: ``"text"`` or ``"audio"``.

    Returns:
        ``(output, state)`` — output is the updated conversation list on the
        text path (renderable by a ``type="messages"`` Chatbot), else None.
    """
    if input_type == "text":
        user_message = (input_data or "").strip()  # tolerate None/empty input
        if user_message:
            state.conversation.append({"role": "user", "content": user_message})
            bot_response = f"Echo: {user_message}"  # simulated bot response
            state.conversation.append({"role": "assistant", "content": bot_response})
        # Return the full history so the "messages"-type Chatbot can render it
        # (the original returned a bare string, which the Chatbot cannot show).
        return state.conversation, state

    if input_type == "audio" and state.pause_detected and state.stream is not None:
        # Convert the buffered raw samples to WAV and store the file path
        # in the conversation history.
        audio_buffer = io.BytesIO()
        segment = AudioSegment(
            state.stream.tobytes(),
            frame_rate=state.sampling_rate,
            sample_width=state.stream.dtype.itemsize,
            channels=1 if len(state.stream.shape) == 1 else state.stream.shape[1],
        )
        segment.export(audio_buffer, format="wav")
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            f.write(audio_buffer.getvalue())
        state.conversation.append(
            {"role": "user", "content": {"path": f.name, "mime_type": "audio/wav"}}
        )
        # Simulated reply audio; a real app would synthesize speech here.
        chatbot_response = b"Simulated response audio content"
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
            f.write(chatbot_response)
        state.conversation.append(
            {"role": "assistant", "content": {"path": f.name, "mime_type": "audio/mp3"}}
        )
        return None, state

    return None, state  # unexpected input: no-op, handled gracefully
# Start recording audio input
def start_recording_user(state: AppState):
    """Turn the microphone back on, unless the conversation was stopped."""
    if state.stopped:
        return None
    return gr.Audio(recording=True)
# Gradio app setup: two-column layout, streaming mic input on the left,
# conversation transcript and reply audio on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Microphone input delivered as (sampling_rate, np.ndarray) tuples.
            input_audio = gr.Audio(label="Input Audio", type="numpy")
            text_input = gr.Textbox(label="Text Input", placeholder="Type your message here...")
        with gr.Column():
            # "messages" type expects a list of {"role", "content"} dicts.
            chatbot = gr.Chatbot(label="Conversation", type="messages")
            output_audio = gr.Audio(label="Output Audio", streaming=True, autoplay=True)

    # Per-session state object shared by all callbacks below.
    state = gr.State(value=AppState())

    # Handle audio input streaming: feed mic chunks into process_audio every
    # 0.5 s (30 s cap); process_audio may return gr.Audio(recording=False)
    # back into input_audio to stop capture.
    stream = input_audio.stream(
        process_audio, [input_audio, state], [input_audio, state], stream_every=0.5, time_limit=30
    )

    # Handle text input submission.
    # NOTE(review): response's text branch must return chatbot-compatible
    # data (a "messages" list) for this wiring to render — confirm.
    text_submit = text_input.submit(
        lambda txt, s: response(txt, s, "text"), [text_input, state], [chatbot, state]
    )

    # Handle audio stop recording: generate the (simulated) reply audio...
    respond = input_audio.stop_recording(
        lambda s: response(None, s, "audio"), [state], [output_audio, state]
    )
    # ...then refresh the transcript from the stored conversation history.
    respond.then(lambda s: s.conversation, [state], [chatbot])

    # Restart recording after the reply audio playback ends.
    restart = output_audio.stop(start_recording_user, [state], [input_audio])

    # Stop conversation button: reset state, stop the mic, and cancel any
    # in-flight audio callbacks.
    cancel = gr.Button("Stop Conversation", variant="stop")
    cancel.click(
        lambda: (AppState(stopped=True), gr.Audio(recording=False)),
        None, [state, input_audio], cancels=[respond, restart]
    )

if __name__ == "__main__":
    demo.launch()