File size: 9,441 Bytes
a446897
 
 
 
 
 
 
81301f1
 
8b6063a
 
3569cbd
8b6063a
3569cbd
 
 
 
 
 
 
 
 
 
 
2bb7b57
 
f33df9c
65f3e54
44e95cf
81301f1
3569cbd
ee7230e
2bb7b57
81301f1
a446897
4840c8f
 
44e95cf
a446897
 
 
4840c8f
65f3e54
 
 
 
ee7230e
65f3e54
4840c8f
8b6063a
4840c8f
44e95cf
4840c8f
 
44e95cf
 
2bb7b57
 
 
 
44e95cf
af79bf4
8b6063a
af79bf4
8b6063a
2bb7b57
 
 
 
 
 
 
 
4840c8f
3569cbd
a4dee07
a446897
 
a1dd3ea
 
 
 
 
 
 
a446897
 
44e95cf
 
 
8b6063a
 
5df3ca5
44e95cf
 
 
2bb7b57
 
 
 
 
8b6063a
af79bf4
8b6063a
 
 
 
 
 
 
 
 
 
 
 
 
 
af79bf4
8b6063a
 
 
 
 
 
f33df9c
 
 
8b6063a
5cb1ef7
8b6063a
 
 
 
 
 
 
f33df9c
 
 
8b6063a
 
 
 
 
 
 
 
 
 
 
 
f33df9c
 
8b6063a
 
 
 
 
 
 
af79bf4
 
 
44e95cf
 
 
 
 
8b6063a
44e95cf
 
 
8b6063a
bb6818c
 
8b6063a
 
 
 
 
 
 
2bb7b57
bb6818c
44e95cf
2bb7b57
 
 
 
 
8b6063a
af79bf4
 
5cb1ef7
 
 
2bb7b57
 
 
 
5cb1ef7
 
8b6063a
 
 
af79bf4
2bb7b57
 
 
 
af79bf4
 
8b6063a
af79bf4
 
 
bb6818c
af79bf4
8b6063a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb6818c
8b6063a
 
 
 
 
 
 
 
 
 
f33df9c
 
8b6063a
 
 
 
 
 
 
af79bf4
 
 
44e95cf
6647d0d
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

#![allow(clippy::result_large_err)]

#[cfg(feature = "whisper")]
extern crate whisper;

use aws_config::BehaviorVersion;
use aws_sdk_transcribestreaming::meta::PKG_VERSION;
use aws_sdk_transcribestreaming::types::LanguageCode;
use futures_util::{stream::StreamExt, SinkExt};
use poem::{
    endpoint::{StaticFileEndpoint, StaticFilesEndpoint},
    get, handler,
    listener::TcpListener,
    web::{
        websocket::{Message, WebSocket},
        Data, Query,
    },
    EndpointExt, IntoResponse, Route, Server,
};
use serde::{Deserialize, Serialize};
use tokio::select;
use tracing::{debug, span};
use tracing_subscriber::{fmt, prelude::*, EnvFilter};

use crate::{config::*, lesson::*};

mod config;
mod lesson;
mod asr;

/// Shared application state. `main` attaches it to the router with
/// `.data(ctx)`, and the WebSocket handlers read it through `Data<&Context>`.
#[derive(Clone)]
struct Context {
    /// Lesson registry used by the handlers: `stream_speaker` calls
    /// `create_lesson`, `stream_listener` calls `get_lesson`. Each handler
    /// clones it before upgrading the connection.
    lessons_manager: LessonsManager,
}

/// Process entry point.
///
/// Installs the tracing subscriber (`fmt` output, level filter taken from the
/// environment), loads the default AWS configuration, and serves:
/// - the `./static` directory at `/` (with `index.html` as the index file),
/// - the speaker WebSocket at `/ws/lesson-speaker` and `/ws/teacher`,
/// - the listener WebSocket at `/ws/lesson-listener` and `/ws/student`,
/// - `./static/index.html` for the `lesson-speaker` and `lesson-listener` pages.
///
/// Runs until the server future finishes or Ctrl-C is received.
#[tokio::main]
async fn main() -> Result<(), std::io::Error> {
    let env_filter = EnvFilter::from_default_env();
    tracing_subscriber::registry()
        .with(fmt::layer())
        .with(env_filter)
        .init();

    debug!("Transcribe client version: {}", PKG_VERSION);

    let aws_cfg = aws_config::load_defaults(BehaviorVersion::latest()).await;
    let lessons_manager = LessonsManager::new(&aws_cfg);
    let ctx = Context { lessons_manager };

    // Both page routes serve the same single-page frontend.
    let index_page = "./static/index.html";
    let static_files = StaticFilesEndpoint::new("./static")
        .show_files_listing()
        .index_file("index.html");

    let routes = Route::new()
        .nest("/", static_files)
        // Each WebSocket handler is reachable under two path aliases.
        .at("/ws/lesson-speaker", get(stream_speaker))
        .at("/ws/teacher", get(stream_speaker))
        .at("/ws/lesson-listener", get(stream_listener))
        .at("/ws/student", get(stream_listener))
        .at("lesson-speaker", StaticFileEndpoint::new(index_page))
        .at("lesson-listener", StaticFileEndpoint::new(index_page))
        .data(ctx);

    let bind_addr = format!("{}:{}", SETTINGS.server.host, SETTINGS.server.port);
    let server = Server::new(TcpListener::bind(bind_addr));

    // Whichever finishes first wins; the other future is dropped. The
    // `ctrl_c` result is discarded, so this branch also fires (returning
    // `Ok(())`) if installing the Ctrl-C handler itself fails.
    select! {
        outcome = server.run(routes) => outcome,
        _ = tokio::signal::ctrl_c() => {
            tracing::info!("Shutting down");
            Ok(())
        },
    }
}

/// Query-string parameters for the speaker WebSocket endpoints
/// (`/ws/lesson-speaker`, `/ws/teacher`).
#[derive(Deserialize, Debug)]
pub struct LessonSpeakerQuery {
    /// Id of the lesson to create via `LessonsManager::create_lesson`.
    id: u32,
    /// Transcription language. It must parse as an AWS Transcribe
    /// `LanguageCode`; otherwise an error text frame is sent and the session ends.
    language: String,
    /// Optional prompt. Accepted but not read yet.
    #[allow(dead_code)] // TODO: use this in the future
    prompt: Option<String>,
}

#[handler]
async fn stream_speaker(
    ctx: Data<&Context>,
    query: Query<LessonSpeakerQuery>,
    ws: WebSocket,
) -> impl IntoResponse {
    let lessons_manager = ctx.lessons_manager.clone();
    ws.on_upgrade(|mut socket| async move {
        let Ok(lang) = query.language.parse::<LanguageCode>() else {
            let _ = socket
                .send(Message::Text(format!("invalid language code: {}", query.language)))
                .await;
            return
        };
        let lesson = lessons_manager
            .create_lesson(
                query.id,
                AsrEngine::AWS,
                lang,
            )
            .await;

        let mut transcribe_rx = lesson.transcript_channel();
        let fut = async {
            loop {
                select! {
                    msg = socket.next() => {
                        let Some(res) = msg else { break };
                        let msg = res?;
                        if msg.is_close() {
                            break
                        }
                        let Message::Binary(bin) = msg else {
                            tracing::warn!("Other: {:?}", msg);
                            continue
                        };
                        let frame = u8_to_i16(&bin);
                        lesson.send(frame).await?
                    },
                    output = transcribe_rx.recv() => {
                        let evt = output?;
                        if evt.is_final {
                            tracing::trace!("Transcribed: {}", evt.transcript);
                        }
                        let evt = LiveLessonTextEvent::Transcription { content: evt.transcript, is_final: evt.is_final };
                        let Ok(json) = serde_json::to_string(&evt) else {
                            tracing::warn!("failed to serialize json: {:?}", evt);
                            continue
                        };
                        socket.send(Message::Text(json)).await?
                    },
                }
            }
            Ok(())
        };

        let span = span!(tracing::Level::TRACE, "lesson_speaker", lesson_id = query.id);
        let _ = span.enter();
        let res: anyhow::Result<()> = fut.await;
        match res {
            Ok(()) => {
                tracing::info!("lesson speaker closed");
            }
            Err(e) => {
                tracing::warn!("lesson speaker error: {}", e);
            }
        }
    })
}

/// Query-string parameters for the listener WebSocket endpoints
/// (`/ws/lesson-listener`, `/ws/student`).
#[derive(Deserialize, Debug)]
pub struct LessonListenerQuery {
    /// Id of an existing lesson to join (looked up with `get_lesson`).
    id: u32,
    /// Language key passed to the lesson's `get_or_init`. Presumably the target
    /// translation language; TODO confirm the format `lesson` expects.
    language: String,
    /// Voice name, parsed into the voice id used for the voice and lip-sync
    /// channels. An unparsable value ends the session after an error text frame.
    voice: String,
}

/// JSON text messages pushed to WebSocket clients.
///
/// Serialized as internally tagged objects with a `"type"` field:
/// - `Transcription` → `{"type":"original","content":…,"isFinal":…}`
/// - `Translation` → `{"type":"Translation","content":…}`
/// - `LipSync` → `{"type":"LipSync","visemes":[…]}`
///
/// NOTE(review): only `Transcription` has a renamed tag. The other tags use
/// the Rust variant names verbatim. The web client presumably matches on these
/// exact strings, so renaming any variant or tag is a wire-format change.
#[derive(Serialize, Debug)]
#[serde(tag = "type")]
enum LiveLessonTextEvent {
    /// Transcript text received from the lesson's transcript channel.
    #[serde(rename = "original")]
    Transcription {
        content: String,
        /// Copied from the transcript event's `is_final` flag.
        #[serde(rename = "isFinal")]
        is_final: bool
    },
    /// Text received from the per-language translated channel.
    Translation { content: String },
    /// Visemes received from the per-voice lip-sync channel.
    LipSync { visemes: Vec<Viseme> },
}
#[handler]
async fn stream_listener(
    ctx: Data<&Context>,
    query: Query<LessonListenerQuery>,
    ws: WebSocket,
) -> impl IntoResponse {
    let lessons_manager = ctx.lessons_manager.clone();

    ws.on_upgrade(|mut socket| async move {
        let voice_id = match query.voice.parse() {
            Ok(id) => id,
            Err(e) => {
                let _ = socket
                    .send(Message::Text(format!("invalid voice id: {}", e)))
                    .await;
                return;
            }
        };

        let lesson_opt = lessons_manager.get_lesson(query.id).await;
        debug!("listener param = {:?}", query);
        let Some(lesson) = lesson_opt else {
            let _ = socket
                .send(Message::Text("lesson not found".to_string()))
                .await;
            return;
        };
        let mut transcript_rx = lesson.transcript_channel();
        let mut lang_lesson = lesson.get_or_init(query.language.clone()).await;
        let mut translate_rx = lang_lesson.translated_channel();
        let mut voice_lesson = lang_lesson.get_or_init(voice_id).await;
        let mut voice_rx = voice_lesson.voice_channel();
        let mut lip_sync_rx = voice_lesson.lip_sync_channel();

        let fut = async {
            loop {
                select! {
                    transcript_poll = transcript_rx.recv() => {
                        let transcript = transcript_poll?;
                        let evt = LiveLessonTextEvent::Transcription {
                            content: transcript.transcript,
                            is_final: transcript.is_final
                        };
                        let Ok(json) = serde_json::to_string(&evt) else {
                            tracing::warn!("failed to serialize: {:?}", evt);
                            continue
                        };
                        tracing::debug!("Transcribed: {}", json);
                        socket.send(Message::Text(json)).await?
                    },
                    translated_poll = translate_rx.recv() => {
                        let translated = translated_poll?;
                        let evt = LiveLessonTextEvent::Translation { content: translated };
                        let Ok(json) = serde_json::to_string(&evt) else {
                            tracing::warn!("failed to serialize: {:?}", evt);
                            continue
                        };
                        tracing::debug!("Translated: {}", json);
                        socket.send(Message::Text(json)).await?
                    },
                    voice_poll = voice_rx.recv() => {
                        let voice = voice_poll?;
                        socket.send(Message::Binary(voice)).await?
                    },
                    visemes_poll = lip_sync_rx.recv() => {
                        let visemes = visemes_poll?;
                        let evt = LiveLessonTextEvent::LipSync { visemes };
                        let Ok(json) = serde_json::to_string(&evt) else {
                            tracing::warn!("failed to serialize: {:?}", evt);
                            continue
                        };
                        socket.send(Message::Text(json)).await?
                    },
                }
            }
        };

        let span = span!(tracing::Level::TRACE, "lesson_listener", lesson_id = query.id);
        let _ = span.enter();
        let res: anyhow::Result<()> = fut.await;
        match res {
            Ok(()) => {
                tracing::info!("lesson listener closed");
            }
            Err(e) => {
                tracing::warn!("lesson listener error: {}", e);
            }
        }
    })
}

/// Decodes a byte buffer as little-endian signed 16-bit samples.
///
/// Each consecutive pair of bytes becomes one `i16`. If `input` has an odd
/// length, the final unpaired byte is ignored.
fn u8_to_i16(input: &[u8]) -> Vec<i16> {
    let pairs = input.chunks_exact(2);
    let mut samples = Vec::with_capacity(pairs.len());
    for pair in pairs {
        samples.push(i16::from_le_bytes([pair[0], pair[1]]));
    }
    samples
}