File size: 5,736 Bytes
a8f8e8d
005a35d
319ab81
 
 
 
 
b9fb710
bc8d394
190e978
1ba4a0c
f797e13
 
 
 
6e090f6
f797e13
 
 
 
a41954d
f797e13
 
 
190e978
1ba4a0c
 
afbc2cb
376e42a
f797e13
 
 
 
8d48726
 
f797e13
8d48726
6ef5b41
f797e13
 
8d48726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9fb710
 
8d48726
 
 
1ba4a0c
8d48726
 
 
1632271
8d48726
 
 
1ba4a0c
f797e13
 
 
 
8d48726
 
 
f797e13
8d48726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f797e13
afbc2cb
 
 
 
319ab81
 
afbc2cb
 
 
 
 
 
 
 
 
 
 
319ab81
3a9492e
 
 
319ab81
3a9492e
319ab81
 
 
 
a41954d
319ab81
 
afbc2cb
 
f0af7d8
319ab81
f797e13
 
afbc2cb
 
 
005a35d
 
afbc2cb
005a35d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import sounddevice
import threading
from starlette.applications import Starlette
from starlette.routing import Route
from sse_starlette.sse import EventSourceResponse
import uvicorn
import asyncio
import numpy as np
import time
import os
import httpx
from queue import Queue
import logging
from datetime import UTC, datetime, timedelta
from time import sleep
import pickle

import speech_recognition as sr

from audio_utils import get_microphone, get_speech_recognizer, get_all_audio_queue, to_audio_array, AudioChunk
from starlette.middleware.cors import CORSMiddleware

logger = logging.getLogger(__name__)

TRANSCRIBING_SERVER = os.getenv('TRANSCRIBING_SERVER', "http://localhost:3535/transcribe")


def _request_transcription(payload):
    """POST *payload* to ``TRANSCRIBING_SERVER`` and return its ``transcribe`` field.

    Returns ``None`` (after logging the error) when the request fails, the
    server answers with a non-2xx status, or the body is not a JSON object with
    a ``transcribe`` key. This keeps a single failed request from killing the
    capture loop in :func:`main`.

    NOTE(review): *payload* is a pickle; presumably the server unpickles it, so
    the server must only accept requests from trusted clients (``pickle.loads``
    on untrusted bytes can execute arbitrary code).
    """
    start = time.time()
    print('start req')
    try:
        response = httpx.post(TRANSCRIBING_SERVER, data=payload)
        print('req status', response.status_code)
        response.raise_for_status()
        transcription = response.json()['transcribe']
    except (httpx.HTTPError, ValueError, KeyError, TypeError):
        # ValueError covers JSON decode errors; KeyError/TypeError a body of the wrong shape.
        logger.exception("Transcription request to %s failed", TRANSCRIBING_SERVER)
        return None
    print('req done', response.text, response.status_code, time.time() - start)
    return transcription


def main(transcriptions_queue):
    """Capture microphone audio and forward fixed-length chunks for transcription.

    Audio is recorded in the background by the speech recognizer and pushed onto
    a local queue. This loop drains that queue and accumulates samples into an
    ``AudioChunk``. Once the chunk is older than ``recording_duration`` seconds,
    its ``end_time`` is set. When the chunk reports ``is_complete`` (defined in
    ``audio_utils``), it is pickled, prefixed with the previous chunk when there
    is one, and POSTed to ``TRANSCRIBING_SERVER``. Every transcription the server
    returns is put on *transcriptions_queue*.

    Runs until interrupted with Ctrl-C. Python raises ``KeyboardInterrupt`` only
    in the main thread, so that exit path is reachable only when ``main`` runs
    there.

    Args:
        transcriptions_queue: thread-safe queue (a ``queue.Queue`` in this file)
            that receives each transcription value.
    """
    recording_duration = 1  # seconds; also the recognizer's phrase_time_limit
    sample_rate = 16000
    energy_threshold = 300

    data_queue = Queue()
    microphone = get_microphone(sample_rate=sample_rate)
    print('microphone is:', microphone)

    with microphone:
        speech_recognizer = get_speech_recognizer(energy_threshold=energy_threshold)
        speech_recognizer.adjust_for_ambient_noise(source=microphone)

    def record_callback(_, audio: sr.AudioData) -> None:
        # Invoked from the recognizer's background listener; Queue is thread-safe.
        data_queue.put(audio.get_raw_data())

    # Keep the returned stopper so the background listener can be shut down on exit.
    stop_listening = speech_recognizer.listen_in_background(
        source=microphone, callback=record_callback, phrase_time_limit=recording_duration
    )

    print("\n🎤 Microphone is now listening...\n")

    prev_audio_array = None
    current_audio_chunk = AudioChunk(start_time=datetime.now(tz=UTC))

    while True:
        try:
            now = datetime.now(tz=UTC)
            # Pull raw recorded audio from the queue.
            if not data_queue.empty():
                # Store end time if we're over the recording time limit.
                if now - current_audio_chunk.start_time > timedelta(seconds=recording_duration):
                    current_audio_chunk.end_time = now

                # Drain everything recorded since the previous pass.
                audio_data = get_all_audio_queue(data_queue)
                audio_np_array = to_audio_array(audio_data)

                if current_audio_chunk.is_complete:
                    print('start serialize')
                    if prev_audio_array is not None:
                        # Presumably sends the previous chunk too so words cut at
                        # the chunk boundary keep some context.
                        serialized = pickle.dumps(
                            np.concatenate((prev_audio_array, current_audio_chunk.audio_array))
                        )
                    else:
                        serialized = pickle.dumps(current_audio_chunk.audio_array)
                    prev_audio_array = current_audio_chunk.audio_array
                    print('end serialize')

                    transcription = _request_transcription(serialized)
                    if transcription is not None:
                        transcriptions_queue.put(transcription)

                    # The audio just drained becomes the start of the next chunk.
                    current_audio_chunk = AudioChunk(
                        audio_array=audio_np_array, start_time=datetime.now(tz=UTC)
                    )
                else:
                    current_audio_chunk.update_array(audio_np_array)

                # Flush stdout
                print("", end="", flush=True)  # noqa: T201

            # Sleep on every pass, including when the queue is empty; otherwise
            # the loop busy-spins a CPU core while waiting for audio.
            sleep(0.25)
        except KeyboardInterrupt:
            current_audio_chunk.end_time = datetime.now(tz=UTC)
            if current_audio_chunk.is_complete:
                # NOTE(review): despite the message, the final chunk is not
                # actually sent for transcription.
                logger.warning("⚠️ Transcribing last chunk...")
            stop_listening(wait_for_stop=False)
            break


    # for i in range(minimum, maximum + 1):
    #     await asyncio.sleep(0.9)
    #     yield dict(data=i)

async def sse(request):
    """Stream queued transcriptions to the client as Server-Sent Events.

    Reads from ``request.app.state.transcriptions_queue``, which ``server`` sets
    before starting uvicorn. That is a thread-safe ``queue.Queue`` filled by the
    capture thread. The queue is polled without blocking, because a blocking
    ``Queue.get()`` inside this coroutine would freeze the whole event loop (every
    other request included) until a transcription arrived.

    NOTE(review): every connected client consumes from the same queue, so each
    transcription is delivered to only one of them.
    """
    from queue import Empty

    transcriptions = request.app.state.transcriptions_queue

    async def event_publisher():
        try:
            while True:
                try:
                    text = transcriptions.get_nowait()
                except Empty:
                    # Nothing ready yet: yield control back to the event loop.
                    await asyncio.sleep(0.2)
                    continue
                yield dict(data=text)
                await asyncio.sleep(0.2)
        except asyncio.CancelledError:
            print(f"Disconnected from client (via refresh/close) {request.client}")
            # Re-raise so the cancellation propagates and the response shuts down cleanly.
            raise

    return EventSourceResponse(event_publisher())


def test(request):
    """Answer ``/`` with the fixed plain-text body ``hello world``.

    Starlette endpoints must return a ``Response``. A bare ``str`` fails when
    Starlette tries to call it as an ASGI application, so the text is wrapped.
    """
    from starlette.responses import PlainTextResponse

    return PlainTextResponse("hello world")

# URL table: "/" returns a fixed greeting; "/test" streams transcriptions as
# Server-Sent Events.
routes = [
    Route('/', endpoint=test),
    Route("/test", endpoint=sse),
]

# NOTE(review): debug=True makes Starlette return tracebacks to clients on errors;
# consider disabling it outside development.
app = Starlette(routes=routes, debug=True)

# Accept any origin, method and header so a page served from elsewhere can
# subscribe to the event stream.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_methods=['*'],
    allow_headers=['*'],
)


def server(transcriptions_queue):
    """Serve ``app`` with uvicorn on port 8343, exposing *transcriptions_queue*.

    The queue is stored on ``app.state`` before the blocking ``uvicorn.run``
    call, so request handlers can reach it through ``request.app.state``.
    """
    state = app.state
    state.transcriptions_queue = transcriptions_queue
    uvicorn.run(app, log_level='info', host="0.0.0.0", port=8343)


if __name__ == '__main__':
    transcriptions_queue = Queue()

    # Serve the SSE endpoint from a daemon thread so it does not keep the
    # process alive once the capture loop below returns.
    server_thread = threading.Thread(target=server, args=(transcriptions_queue,), daemon=True)
    server_thread.start()

    # Run the capture loop on the main thread. Python delivers KeyboardInterrupt
    # only to the main thread, so this is what lets main()'s Ctrl-C handler run.
    # With both workers as non-daemon threads, Ctrl-C only interrupted join()
    # and the process kept running.
    main(transcriptions_queue)