import gradio as gr
import numpy as np
import io
import tempfile
from pydub import AudioSegment
from dataclasses import dataclass, field


@dataclass
class AppState:
    """Mutable per-session state shared across Gradio events."""
    stream: np.ndarray | None = None
    sampling_rate: int = 0
    pause_detected: bool = False
    stopped: bool = False
    started_talking: bool = False
    conversation: list = field(default_factory=list)  # default_factory avoids a shared mutable default


# Process streamed microphone chunks and decide when to stop recording.
def process_audio(audio: tuple, state: AppState):
    # `audio` arrives as (sampling_rate, np.ndarray) from gr.Audio(type="numpy").
    if state.stream is None:
        state.stream = audio[1]
        state.sampling_rate = audio[0]
    else:
        state.stream = np.concatenate((state.stream, audio[1]))

    # Detect a "pause" (for simplicity: stop once more than one second of
    # audio has accumulated; a real detector would measure trailing silence).
    state.pause_detected = len(state.stream) > state.sampling_rate * 1

    if state.pause_detected:
        return gr.Audio(recording=False), state  # Stop recording
    return None, state
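

# The length check above is only a stand-in for real pause detection. A
# minimal energy-based sketch follows; the 0.01 RMS threshold and one-second
# window are assumptions, not values from this app, and will need tuning.
def detect_pause(stream: np.ndarray, sampling_rate: int, rms_threshold: float = 0.01) -> bool:
    """Return True if the trailing ~1 s of audio is near-silent (sketch)."""
    if len(stream) < sampling_rate:
        return False  # not enough audio buffered yet to judge
    tail = stream[-sampling_rate:].astype(np.float32)
    # Normalize integer PCM into [-1, 1] before measuring energy.
    if np.issubdtype(stream.dtype, np.integer):
        tail /= np.iinfo(stream.dtype).max
    rms = float(np.sqrt(np.mean(tail**2)))
    return rms < rms_threshold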


# Generate the chatbot response based on user input (audio or text).
# Note: this function is a generator (it yields below), so every exit path
# must yield its outputs; a bare `return value` in a generator is lost.
def response(user_input, state: AppState, input_type: str):
    if input_type == "text":
        # Handle text input
        state.conversation.append({"role": "user", "content": user_input})
        bot_response = f"Echo: {user_input}"  # Simulated response
        state.conversation.append({"role": "assistant", "content": bot_response})
        yield state.conversation, state
        return

    # Handle audio input only once a pause was detected
    if not state.pause_detected:
        yield None, state
        return
    # Convert the buffered audio to WAV and store it in the conversation history
    audio_buffer = io.BytesIO()
    segment = AudioSegment(
        state.stream.tobytes(),
        frame_rate=state.sampling_rate,
        sample_width=state.stream.dtype.itemsize,
        channels=1 if state.stream.ndim == 1 else state.stream.shape[1],
    )
    segment.export(audio_buffer, format="wav")

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        f.write(audio_buffer.getvalue())
    state.conversation.append(
        {"role": "user", "content": {"path": f.name, "mime_type": "audio/wav"}}
    )
    # Simulate the bot's response (replace with real mini-omni inference).
    # These bytes are only a placeholder, not a decodable MP3 stream.
    chatbot_response = b"Simulated response audio content"
    with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
        f.write(chatbot_response)
    state.conversation.append(
        {"role": "assistant", "content": {"path": f.name, "mime_type": "audio/mp3"}}
    )
    yield None, state
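

# output_audio below is created with streaming=True, so a real implementation
# would yield audio incrementally rather than the single placeholder above.
# A minimal sketch, assuming the model returns one full waveform as
# (sampling_rate, np.ndarray); the 0.5 s chunk size is an assumption.
def stream_waveform(sampling_rate: int, waveform: np.ndarray):
    """Yield ~0.5 s chunks suitable for a streaming gr.Audio output (sketch)."""
    chunk = int(sampling_rate * 0.5)
    for start in range(0, len(waveform), chunk):
        yield sampling_rate, waveform[start : start + chunk]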


# Start recording audio input again (used after bot playback finishes)
def start_recording_user(state: AppState):
    if not state.stopped:
        return gr.Audio(recording=True)


# Gradio interface setup
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(label="Input Audio", sources=["microphone"], type="numpy")
            text_input = gr.Textbox(label="Text Input", placeholder="Type your message here...")
        with gr.Column():
            chatbot = gr.Chatbot(label="Conversation", type="messages")
            output_audio = gr.Audio(label="Output Audio", streaming=True, autoplay=True)

    state = gr.State(value=AppState())
    # Stream microphone input into process_audio every half second
    stream = input_audio.stream(
        process_audio,
        [input_audio, state],
        [input_audio, state],
        stream_every=0.5,
        time_limit=30,
    )
    # Handle responses for both text and audio inputs. Event inputs must be
    # components, not literals, so thin generator wrappers pass the input_type
    # flag while still letting Gradio stream the yields from response().
    def text_response(txt, s):
        yield from response(txt, s, "text")

    def audio_response(s):
        yield from response(None, s, "audio")

    text_submit = text_input.submit(text_response, [text_input, state], [chatbot, state])
    respond = input_audio.stop_recording(audio_response, [state], [output_audio, state])
    respond.then(lambda s: s.conversation, [state], [chatbot])
    # Restart recording when audio playback stops
    restart = output_audio.stop(start_recording_user, [state], [input_audio])

    # Stop button: reset state, stop recording, and cancel in-flight events
    cancel = gr.Button("Stop Conversation", variant="stop")
    cancel.click(
        lambda: (AppState(stopped=True), gr.Audio(recording=False)),
        None,
        [state, input_audio],
        cancels=[respond, restart],
    )


if __name__ == "__main__":
    demo.launch()
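
# Environment notes (assumptions, not pinned by this file): this sketch
# targets a recent Gradio 4.x release plus numpy and pydub. pydub shells
# out to ffmpeg for MP3 work; the plain WAV export used above does not
# require it.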