# test-gpt-omni / app.py
# Hugging Face Space by TuringsSolutions (commit bacc85c, 4.19 kB)
import gradio as gr
import numpy as np
import io
import tempfile
from pydub import AudioSegment
from dataclasses import dataclass, field
@dataclass
class AppState:
    """Mutable per-session state shared by the Gradio event handlers."""

    # Raw audio samples accumulated from the microphone stream
    # (None until the first chunk arrives in process_audio).
    stream: np.ndarray | None = None
    # Sample rate (Hz), taken from the first streamed audio chunk.
    sampling_rate: int = 0
    # True once more than ~1 second of audio has been buffered
    # (crude pause heuristic — see process_audio).
    pause_detected: bool = False
    # Set by the "Stop Conversation" button; suppresses re-recording.
    stopped: bool = False
    # NOTE(review): not referenced anywhere in this file — possibly vestigial.
    started_talking: bool = False
    # Chat history in Gradio "messages" format: {"role": ..., "content": ...}.
    conversation: list = field(default_factory=list)
# Process audio input and detect pauses
def process_audio(audio: tuple, state: AppState):
    """Accumulate streamed microphone chunks and flag a "pause".

    ``audio`` is Gradio's ``(sample_rate, samples)`` tuple. The pause
    heuristic is simply: more than one second of audio has been buffered.
    Returns ``(audio_component_update_or_None, state)``.
    """
    sample_rate, chunk = audio
    if state.stream is None:
        # First chunk: start the buffer and remember the sample rate.
        state.stream = chunk
        state.sampling_rate = sample_rate
    else:
        state.stream = np.concatenate((state.stream, chunk))

    state.pause_detected = len(state.stream) > state.sampling_rate * 1
    if state.pause_detected:
        return gr.Audio(recording=False), state  # Stop recording
    return None, state
# Generate response based on input type (text or audio)
def response(input_data, state: AppState, input_type: str):
    """Yield ``(output, state)`` updates for one text or audio user turn.

    This is deliberately a *uniform* generator. The original mixed
    ``return value`` with ``yield`` in the same function body, which makes
    the whole function a generator whose ``return`` values are silently
    discarded — the text path never produced any output.

    Text turns yield the full conversation list (a ``type="messages"``
    Chatbot expects the message list, not a bare string); audio turns
    yield the path of the generated reply audio file.
    """
    if input_type == "text":
        # Guard against None as well as empty/whitespace-only input.
        user_message = (input_data or "").strip()
        if not user_message:
            # Nothing to answer; re-emit the unchanged history.
            yield state.conversation, state
            return
        state.conversation.append({"role": "user", "content": user_message})
        bot_response = f"Echo: {user_message}"  # Simulated bot response
        state.conversation.append({"role": "assistant", "content": bot_response})
        yield state.conversation, state
        return

    if input_type == "audio" and state.pause_detected:
        # Convert the accumulated raw stream to WAV and store it in history.
        audio_buffer = io.BytesIO()
        segment = AudioSegment(
            state.stream.tobytes(),
            frame_rate=state.sampling_rate,
            sample_width=state.stream.dtype.itemsize,
            channels=1 if len(state.stream.shape) == 1 else state.stream.shape[1],
        )
        segment.export(audio_buffer, format="wav")
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            f.write(audio_buffer.getvalue())
        state.conversation.append(
            {"role": "user", "content": {"path": f.name, "mime_type": "audio/wav"}}
        )

        chatbot_response = b"Simulated response audio content"
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
            f.write(chatbot_response)
        state.conversation.append(
            {"role": "assistant", "content": {"path": f.name, "mime_type": "audio/mp3"}}
        )
        # Emit the generated file so output_audio actually plays it
        # (the original yielded None here, so the reply never played).
        yield f.name, state
        return

    yield None, state  # Handle unexpected input cases gracefully
# Start recording audio input
def start_recording_user(state: AppState):
    """Re-arm the microphone unless the conversation has been stopped.

    Returns an Audio component update when recording should resume, or
    None (no update) when the user pressed "Stop Conversation".
    """
    return None if state.stopped else gr.Audio(recording=True)
# Gradio app setup
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Microphone input delivered as (sample_rate, np.ndarray) tuples.
            input_audio = gr.Audio(label="Input Audio", type="numpy")
            text_input = gr.Textbox(label="Text Input", placeholder="Type your message here...")
        with gr.Column():
            # "messages" type: history is a list of {"role", "content"} dicts.
            chatbot = gr.Chatbot(label="Conversation", type="messages")
            output_audio = gr.Audio(label="Output Audio", streaming=True, autoplay=True)
    # One AppState instance per browser session.
    state = gr.State(value=AppState())
    # Handle audio input streaming: feed ~0.5 s chunks into process_audio,
    # which stops recording once its pause heuristic fires.
    stream = input_audio.stream(
        process_audio, [input_audio, state], [input_audio, state], stream_every=0.5, time_limit=30
    )
    # Handle text input submission
    text_submit = text_input.submit(
        lambda txt, s: response(txt, s, "text"), [text_input, state], [chatbot, state]
    )
    # Handle audio stop recording
    respond = input_audio.stop_recording(
        lambda s: response(None, s, "audio"), [state], [output_audio, state]
    )
    # After the audio reply, refresh the chat display from the stored history.
    respond.then(lambda s: s.conversation, [state], [chatbot])
    # Restart recording after audio playback ends
    restart = output_audio.stop(start_recording_user, [state], [input_audio])
    # Stop conversation button: reset state with stopped=True, halt recording,
    # and cancel any in-flight respond/restart events.
    cancel = gr.Button("Stop Conversation", variant="stop")
    cancel.click(
        lambda: (AppState(stopped=True), gr.Audio(recording=False)),
        None, [state, input_audio], cancels=[respond, restart]
    )

if __name__ == "__main__":
    demo.launch()