import gradio as gr
import numpy as np
import io
import tempfile
from pydub import AudioSegment
from dataclasses import dataclass, field

@dataclass
class AppState:
    stream: np.ndarray | None = None  # accumulated microphone samples
    sampling_rate: int = 0            # sampling rate of the stream
    pause_detected: bool = False      # set once the user stops talking
    stopped: bool = False             # set when the conversation is cancelled
    started_talking: bool = False     # reserved; not used below
    conversation: list = field(default_factory=list)  # messages-format history

# Accumulate streamed microphone chunks and decide when to stop recording.
def process_audio(audio: tuple, state: AppState):
    if state.stream is None:
        # First chunk: record the sampling rate and start the buffer.
        state.sampling_rate, state.stream = audio
    else:
        state.stream = np.concatenate((state.stream, audio[1]))

    # Placeholder heuristic: treat more than one second of accumulated audio
    # as a pause (see the energy-based sketch below for a real detector).
    state.pause_detected = len(state.stream) > state.sampling_rate

    if state.pause_detected:
        # Stop recording; stop_recording() below then triggers a response.
        return gr.Audio(recording=False), state
    return None, state
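
# A more realistic pause detector would look at recent signal energy rather
# than at total length. This helper is an illustrative sketch only (the 0.5 s
# window and the 0.2 ratio are assumptions, not tuned values): it reports a
# pause when the trailing window is much quieter than the recording overall.
def detect_pause_by_energy(stream: np.ndarray, sampling_rate: int) -> bool:
    window = int(sampling_rate * 0.5)  # samples in the trailing 0.5 s
    if len(stream) < 2 * window:
        return False  # too little audio to compare against
    samples = stream.astype(np.float32)
    overall_rms = np.sqrt(np.mean(samples ** 2)) + 1e-8  # avoid divide-by-zero
    tail_rms = np.sqrt(np.mean(samples[-window:] ** 2))
    return tail_rms < 0.2 * overall_rms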

# Generate a response for either text or audio input. The replies here are
# simulated: text is echoed back, and the "audio" reply is written as
# placeholder bytes rather than a playable file. Every path returns exactly
# once, so this is a plain function (the lambdas below would not iterate a
# generator correctly).
def response(input_data, state: AppState, input_type: str):
    if input_type == "text":
        state.conversation.append({"role": "user", "content": input_data})
        bot_response = f"Echo: {input_data}"
        state.conversation.append({"role": "assistant", "content": bot_response})
        # The chatbot uses the messages format, so return the full history.
        return state.conversation, state

    if not state.pause_detected:
        return None, state

    # Convert the accumulated numpy buffer into a WAV file for the history.
    audio_buffer = io.BytesIO()
    segment = AudioSegment(
        state.stream.tobytes(),
        frame_rate=state.sampling_rate,
        sample_width=state.stream.dtype.itemsize,
        channels=1 if len(state.stream.shape) == 1 else state.stream.shape[1],
    )
    segment.export(audio_buffer, format="wav")

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        f.write(audio_buffer.getvalue())
    state.conversation.append({"role": "user", "content": {"path": f.name, "mime_type": "audio/wav"}})

    # Placeholder reply: these bytes are not a decodable MP3, so nothing is
    # played. A real app would synthesize speech here and could stream it
    # back (see the stream_waveform sketch below).
    chatbot_response = b"Simulated response audio content"
    with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
        f.write(chatbot_response)
    state.conversation.append({"role": "assistant", "content": {"path": f.name, "mime_type": "audio/mp3"}})

    # Reset the buffer so the next recording starts fresh.
    state.stream = None
    state.pause_detected = False
    return None, state
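
# If the assistant reply were real audio, a generator handler could stream it
# back through the streaming `output_audio` component chunk by chunk. A
# minimal sketch, assuming the reply is already a numpy waveform (the
# one-second chunk size is an arbitrary choice, and whether tuples or file
# paths are expected depends on the Gradio version):
def stream_waveform(waveform: np.ndarray, sampling_rate: int):
    chunk = sampling_rate  # roughly one second of samples per yield
    for start in range(0, len(waveform), chunk):
        yield (sampling_rate, waveform[start:start + chunk])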

# Restart recording once the bot's reply has finished playing, unless the
# conversation has been stopped.
def start_recording_user(state: AppState):
    if not state.stopped:
        return gr.Audio(recording=True)
    return None  # conversation stopped: leave recording off

# Gradio app setup
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(label="Input Audio", type="numpy")  # No 'source' argument
            text_input = gr.Textbox(label="Text Input", placeholder="Type your message here...")
        with gr.Column():
            chatbot = gr.Chatbot(label="Conversation", type="messages")
            output_audio = gr.Audio(label="Output Audio", streaming=True, autoplay=True)

    state = gr.State(value=AppState())

    # Handle audio input streaming
    stream = input_audio.stream(
        process_audio, [input_audio, state], [input_audio, state], stream_every=0.5, time_limit=30
    )

    # Handle text input submission
    text_submit = text_input.submit(
        lambda txt, s: response(txt, s, "text"), [text_input, state], [chatbot, state]
    )

    # Handle audio stop recording
    respond = input_audio.stop_recording(
        lambda s: response(None, s, "audio"), [state], [output_audio, state]
    )
    respond.then(lambda s: s.conversation, [state], [chatbot])

    # Restart recording after audio playback ends
    restart = output_audio.stop(start_recording_user, [state], [input_audio])

    # Stop conversation button
    cancel = gr.Button("Stop Conversation", variant="stop")
    cancel.click(
        lambda: (AppState(stopped=True), gr.Audio(recording=False)),
        None, [state, input_audio], cancels=[respond, restart]
    )

if __name__ == "__main__":
    demo.launch()