File size: 4,185 Bytes
b6ab738
 
 
 
 
 
 
 
 
 
 
 
 
 
94cfbe9
b6ab738
94cfbe9
b6ab738
 
 
 
 
 
 
466a41a
b6ab738
 
 
bacc85c
b6ab738
 
94cfbe9
 
466a41a
bacc85c
 
 
 
 
 
 
466a41a
 
 
bacc85c
 
 
 
 
 
 
 
 
 
b6ab738
bacc85c
 
 
b6ab738
bacc85c
 
 
 
b6ab738
bacc85c
b6ab738
bacc85c
b6ab738
466a41a
b6ab738
 
 
 
94cfbe9
b6ab738
 
 
bacc85c
466a41a
b6ab738
 
 
466a41a
b6ab738
 
466a41a
b6ab738
 
 
466a41a
94cfbe9
466a41a
 
 
94cfbe9
 
 
 
 
b6ab738
 
94cfbe9
b6ab738
466a41a
94cfbe9
b6ab738
94cfbe9
 
 
 
b6ab738
 
466a41a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import numpy as np
import io
import tempfile
from pydub import AudioSegment
from dataclasses import dataclass, field

@dataclass
class AppState:
    """Per-session mutable state threaded through every Gradio callback."""

    # Raw audio samples accumulated for the current utterance; None until
    # the first streamed chunk arrives in process_audio().
    stream: np.ndarray | None = None
    # Sample rate (Hz) of `stream`, copied from the first audio chunk.
    sampling_rate: int = 0
    # Set by process_audio() once more than ~1 s of audio has buffered;
    # response() only processes audio when this is True.
    pause_detected: bool = False
    # Set to True by the "Stop Conversation" button; checked by
    # start_recording_user() to avoid re-arming the microphone.
    stopped: bool = False
    # NOTE(review): never written anywhere in this file — appears unused;
    # confirm before removing.
    started_talking: bool = False
    # Chat history as {"role": ..., "content": ...} message dicts, the
    # format expected by gr.Chatbot(type="messages").
    conversation: list = field(default_factory=list)

# Process audio input and detect pauses
def process_audio(audio: tuple, state: "AppState", pause_seconds: float = 1.0):
    """Accumulate one streamed audio chunk and stop recording after a pause.

    Parameters
    ----------
    audio : tuple
        ``(sampling_rate, samples)`` chunk as delivered by ``gr.Audio``
        with ``type="numpy"``.
    state : AppState
        Per-session state; ``stream``, ``sampling_rate`` and
        ``pause_detected`` are updated in place.
    pause_seconds : float, optional
        Seconds of buffered audio after which the utterance is considered
        finished. Defaults to 1.0, preserving the previously hard-coded
        ``sampling_rate * 1`` threshold.

    Returns
    -------
    tuple
        ``(update, state)`` where ``update`` is ``gr.Audio(recording=False)``
        once the pause is detected, else ``None`` (no UI change).
    """
    sampling_rate, samples = audio
    if state.stream is None:
        # First chunk: remember the sample rate and start the buffer.
        state.stream = samples
        state.sampling_rate = sampling_rate
    else:
        state.stream = np.concatenate((state.stream, samples))

    # Simulated pause detection: "pause" == more than `pause_seconds`
    # of audio buffered in total.
    state.pause_detected = len(state.stream) > state.sampling_rate * pause_seconds

    if state.pause_detected:
        return gr.Audio(recording=False), state  # Stop recording
    return None, state

# Generate response based on input type (text or audio)
def response(input_data, state: "AppState", input_type: str):
    """Yield a (simulated) reply for text or audio input.

    This is a generator (Gradio streams each ``yield``). The previous
    version mixed ``return value`` with ``yield``: inside a generator a
    ``return`` only raises StopIteration, so the text branch's results
    were silently discarded and the chatbot never updated. Every result
    path now ``yield``s.

    Parameters
    ----------
    input_data : str | None
        The submitted text for ``input_type == "text"``; unused (None)
        for audio, where the buffered ``state.stream`` is used instead.
    state : AppState
        Per-session state; ``conversation`` is appended to in place.
    input_type : str
        ``"text"`` or ``"audio"``.

    Yields
    ------
    tuple
        ``(component_value, state)`` — the message list for the chatbot
        on the text path, ``None`` on the audio/fallthrough paths.
    """
    if input_type == "text":
        # Tolerate None / whitespace-only input without crashing.
        user_message = (input_data or "").strip()
        if not user_message:
            # Nothing to echo; leave the conversation unchanged.
            yield state.conversation, state
            return

        state.conversation.append({"role": "user", "content": user_message})
        bot_response = f"Echo: {user_message}"  # Simulated bot response
        state.conversation.append({"role": "assistant", "content": bot_response})
        # The chatbot is type="messages": it expects the full list of
        # role/content dicts, not a bare string.
        yield state.conversation, state
        return

    if input_type == "audio" and state.pause_detected:
        # Convert the buffered PCM to WAV and store it in the history.
        audio_buffer = io.BytesIO()
        segment = AudioSegment(
            state.stream.tobytes(),
            frame_rate=state.sampling_rate,
            sample_width=state.stream.dtype.itemsize,
            channels=1 if len(state.stream.shape) == 1 else state.stream.shape[1]
        )
        segment.export(audio_buffer, format="wav")

        # delete=False: the file must outlive this handler so Gradio can
        # serve it from the conversation history.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            f.write(audio_buffer.getvalue())
        state.conversation.append({"role": "user", "content": {"path": f.name, "mime_type": "audio/wav"}})

        # Simulated assistant audio reply.
        chatbot_response = b"Simulated response audio content"
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
            f.write(chatbot_response)
        state.conversation.append({"role": "assistant", "content": {"path": f.name, "mime_type": "audio/mp3"}})

        yield None, state
        return

    # Unknown input type, or audio without a detected pause: no update.
    yield None, state

# Start recording audio input
def start_recording_user(state: AppState):
    """Re-arm the microphone for the next utterance.

    Returns a ``gr.Audio(recording=True)`` update, or ``None`` (no UI
    change) once the conversation has been stopped.
    """
    if state.stopped:
        return None
    return gr.Audio(recording=True)

# Gradio app setup: left column takes user input (mic + textbox), right
# column shows the conversation and plays the assistant's audio reply.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(label="Input Audio", type="numpy")
            text_input = gr.Textbox(label="Text Input", placeholder="Type your message here...")
        with gr.Column():
            chatbot = gr.Chatbot(label="Conversation", type="messages")
            output_audio = gr.Audio(label="Output Audio", streaming=True, autoplay=True)

    # One AppState per browser session.
    state = gr.State(value=AppState())

    # Handle audio input streaming: every 0.5 s chunk goes through
    # process_audio, which flips recording off once a pause is detected.
    stream = input_audio.stream(
        process_audio, [input_audio, state], [input_audio, state], stream_every=0.5, time_limit=30
    )

    # Handle text input submission.
    # NOTE(review): outputs map to [chatbot, state], so `response` must
    # produce (messages_list, state) for the type="messages" chatbot —
    # verify it actually yields in that shape.
    text_submit = text_input.submit(
        lambda txt, s: response(txt, s, "text"), [text_input, state], [chatbot, state]
    )

    # Handle audio stop recording: response() consumes the buffered
    # stream; the chatbot is refreshed afterwards via .then().
    respond = input_audio.stop_recording(
        lambda s: response(None, s, "audio"), [state], [output_audio, state]
    )
    respond.then(lambda s: s.conversation, [state], [chatbot])

    # Restart recording after audio playback ends.
    restart = output_audio.stop(start_recording_user, [state], [input_audio])

    # Stop conversation button: resets session state (stopped=True so
    # recording is not re-armed) and cancels in-flight audio handlers.
    cancel = gr.Button("Stop Conversation", variant="stop")
    cancel.click(
        lambda: (AppState(stopped=True), gr.Audio(recording=False)),
        None, [state, input_audio], cancels=[respond, restart]
    )

if __name__ == "__main__":
    demo.launch()