mingyang91 commited on
Commit
2382376
·
verified ·
1 Parent(s): 03e88fd

re-org code structure

Browse files
Files changed (9) hide show
  1. Cargo.lock +12 -12
  2. Cargo.toml +8 -0
  3. src/asr/aws.rs +2 -0
  4. src/asr/mod.rs +1 -2
  5. src/lesson.rs +8 -13
  6. src/lib.rs +400 -0
  7. src/main.rs +3 -494
  8. tests/stream_single.rs +115 -0
  9. whisper/src/handler.rs +0 -1
Cargo.lock CHANGED
@@ -696,9 +696,9 @@ dependencies = [
696
 
697
  [[package]]
698
  name = "config"
699
- version = "0.13.3"
700
  source = "registry+https://github.com/rust-lang/crates.io-index"
701
- checksum = "d379af7f68bfc21714c6c7dea883544201741d2ce8274bb12fa54f89507f52a7"
702
  dependencies = [
703
  "async-trait",
704
  "json5",
@@ -838,9 +838,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
838
 
839
  [[package]]
840
  name = "form_urlencoded"
841
- version = "1.2.0"
842
  source = "registry+https://github.com/rust-lang/crates.io-index"
843
- checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652"
844
  dependencies = [
845
  "percent-encoding",
846
  ]
@@ -853,9 +853,9 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
853
 
854
  [[package]]
855
  name = "futures"
856
- version = "0.3.28"
857
  source = "registry+https://github.com/rust-lang/crates.io-index"
858
- checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40"
859
  dependencies = [
860
  "futures-channel",
861
  "futures-core",
@@ -1179,9 +1179,9 @@ dependencies = [
1179
 
1180
  [[package]]
1181
  name = "idna"
1182
- version = "0.4.0"
1183
  source = "registry+https://github.com/rust-lang/crates.io-index"
1184
- checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c"
1185
  dependencies = [
1186
  "unicode-bidi",
1187
  "unicode-normalization",
@@ -1487,9 +1487,9 @@ checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
1487
 
1488
  [[package]]
1489
  name = "percent-encoding"
1490
- version = "2.3.0"
1491
  source = "registry+https://github.com/rust-lang/crates.io-index"
1492
- checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94"
1493
 
1494
  [[package]]
1495
  name = "pest"
@@ -2535,9 +2535,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
2535
 
2536
  [[package]]
2537
  name = "url"
2538
- version = "2.4.1"
2539
  source = "registry+https://github.com/rust-lang/crates.io-index"
2540
- checksum = "143b538f18257fac9cad154828a57c6bf5157e1aa604d4816b5995bf6de87ae5"
2541
  dependencies = [
2542
  "form_urlencoded",
2543
  "idna",
 
696
 
697
  [[package]]
698
  name = "config"
699
+ version = "0.13.4"
700
  source = "registry+https://github.com/rust-lang/crates.io-index"
701
+ checksum = "23738e11972c7643e4ec947840fc463b6a571afcd3e735bdfce7d03c7a784aca"
702
  dependencies = [
703
  "async-trait",
704
  "json5",
 
838
 
839
  [[package]]
840
  name = "form_urlencoded"
841
+ version = "1.2.1"
842
  source = "registry+https://github.com/rust-lang/crates.io-index"
843
+ checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456"
844
  dependencies = [
845
  "percent-encoding",
846
  ]
 
853
 
854
  [[package]]
855
  name = "futures"
856
+ version = "0.3.29"
857
  source = "registry+https://github.com/rust-lang/crates.io-index"
858
+ checksum = "da0290714b38af9b4a7b094b8a37086d1b4e61f2df9122c3cad2577669145335"
859
  dependencies = [
860
  "futures-channel",
861
  "futures-core",
 
1179
 
1180
  [[package]]
1181
  name = "idna"
1182
+ version = "0.5.0"
1183
  source = "registry+https://github.com/rust-lang/crates.io-index"
1184
+ checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6"
1185
  dependencies = [
1186
  "unicode-bidi",
1187
  "unicode-normalization",
 
1487
 
1488
  [[package]]
1489
  name = "percent-encoding"
1490
+ version = "2.3.1"
1491
  source = "registry+https://github.com/rust-lang/crates.io-index"
1492
+ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
1493
 
1494
  [[package]]
1495
  name = "pest"
 
2535
 
2536
  [[package]]
2537
  name = "url"
2538
+ version = "2.5.0"
2539
  source = "registry+https://github.com/rust-lang/crates.io-index"
2540
+ checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633"
2541
  dependencies = [
2542
  "form_urlencoded",
2543
  "idna",
Cargo.toml CHANGED
@@ -3,6 +3,14 @@ name = "polyhedron"
3
  version = "0.1.0"
4
  edition = "2021"
5
 
 
 
 
 
 
 
 
 
6
  [dependencies]
7
  anyhow = "1.0"
8
  async-stream = "0.3"
 
3
  version = "0.1.0"
4
  edition = "2021"
5
 
6
+ [lib]
7
+ name = "polyhedron"
8
+ path = "src/lib.rs"
9
+
10
+ [[bin]]
11
+ name = "polyhedron"
12
+ path = "src/main.rs"
13
+
14
  [dependencies]
15
  anyhow = "1.0"
16
  async-stream = "0.3"
src/asr/aws.rs CHANGED
@@ -12,6 +12,7 @@ use tokio::select;
12
  use tokio::sync::broadcast::Receiver;
13
  use tokio_stream::Stream;
14
  use futures_util::TryStreamExt;
 
15
  use crate::asr::{ASR, Event, slice_i16_to_u8};
16
 
17
  pub struct AwsAsr {
@@ -58,6 +59,7 @@ impl AwsAsr {
58
  let output_stream = to_stream(output);
59
  output_stream
60
  .try_for_each(|text| async {
 
61
  let _ = shared_speaker_transcript.send(text);
62
  Ok(())
63
  })
 
12
  use tokio::sync::broadcast::Receiver;
13
  use tokio_stream::Stream;
14
  use futures_util::TryStreamExt;
15
+ use tracing::trace;
16
  use crate::asr::{ASR, Event, slice_i16_to_u8};
17
 
18
  pub struct AwsAsr {
 
59
  let output_stream = to_stream(output);
60
  output_stream
61
  .try_for_each(|text| async {
62
+ trace!("here");
63
  let _ = shared_speaker_transcript.send(text);
64
  Ok(())
65
  })
src/asr/mod.rs CHANGED
@@ -18,8 +18,7 @@ pub(crate) trait ASR {
18
  }
19
 
20
 
21
- #[allow(dead_code)]
22
- pub(crate) fn slice_i16_to_u8(slice: &[i16]) -> Vec<u8> {
23
  slice
24
  .iter()
25
  .flat_map(|&sample| {
 
18
  }
19
 
20
 
21
+ pub fn slice_i16_to_u8(slice: &[i16]) -> Vec<u8> {
 
22
  slice
23
  .iter()
24
  .flat_map(|&sample| {
src/lesson.rs CHANGED
@@ -1,9 +1,7 @@
1
  use aws_config::SdkConfig;
2
- use aws_sdk_polly::primitives::ByteStream;
3
  use aws_sdk_polly::types::{Engine, OutputFormat, SpeechMarkType, VoiceId};
4
  use aws_sdk_transcribestreaming::types::{LanguageCode};
5
  use futures_util::future::try_join;
6
- use serde::{Deserialize, Serialize};
7
  use std::collections::BTreeMap;
8
  use std::fmt::{Debug, Formatter};
9
  use std::io::BufRead;
@@ -17,11 +15,14 @@ use crate::asr::{Event, aws::AwsAsr, ASR};
17
 
18
  #[cfg(feature = "whisper")]
19
  use crate::asr::whisper::WhisperAsr;
 
 
 
20
 
21
  pub struct InnerLessonsManager {
22
  translate_client: aws_sdk_translate::Client,
23
  polly_client: aws_sdk_polly::Client,
24
- lessons: Arc<RwLock<BTreeMap<u32, Lesson>>>,
25
  }
26
 
27
  #[derive(Clone)]
@@ -43,7 +44,7 @@ impl Deref for LessonsManager {
43
  }
44
  }
45
 
46
- pub(crate) enum AsrEngine {
47
  AWS,
48
  #[allow(dead_code)]
49
  #[cfg(feature = "whisper")]
@@ -69,7 +70,7 @@ impl AsrEngine {
69
  }
70
 
71
  impl LessonsManager {
72
- pub(crate) fn new(sdk_config: &SdkConfig) -> Self {
73
  let translate_client = aws_sdk_translate::Client::new(sdk_config);
74
  let polly_client = aws_sdk_polly::Client::new(sdk_config);
75
  let inner = InnerLessonsManager {
@@ -80,14 +81,14 @@ impl LessonsManager {
80
  LessonsManager { inner: Arc::new(inner) }
81
  }
82
 
83
- pub(crate) async fn create_lesson(&self, id: u32, engine: AsrEngine, speaker_lang: LanguageCode) -> Lesson {
84
  let mut map = self.lessons.write().await;
85
  let lesson: Lesson = InnerLesson::new(self.clone(), engine, speaker_lang).await.into();
86
  map.insert(id, lesson.clone());
87
  lesson
88
  }
89
 
90
- pub(crate) async fn get_lesson(&self, id: u32) -> Option<Lesson> {
91
  let map = self.lessons.read().await;
92
  map.get(&id).cloned()
93
  }
@@ -434,12 +435,6 @@ impl Drop for InnerVoiceLesson {
434
  }
435
 
436
 
437
- // {"time":180,"type":"viseme","value":"r"}
438
- #[derive(Debug, Deserialize, Clone, Serialize)]
439
- pub(crate) struct Viseme {
440
- time: u32,
441
- value: String,
442
- }
443
 
444
  #[derive(Debug)]
445
  enum SynthesizeError {
 
1
  use aws_config::SdkConfig;
 
2
  use aws_sdk_polly::types::{Engine, OutputFormat, SpeechMarkType, VoiceId};
3
  use aws_sdk_transcribestreaming::types::{LanguageCode};
4
  use futures_util::future::try_join;
 
5
  use std::collections::BTreeMap;
6
  use std::fmt::{Debug, Formatter};
7
  use std::io::BufRead;
 
15
 
16
  #[cfg(feature = "whisper")]
17
  use crate::asr::whisper::WhisperAsr;
18
+ use crate::Viseme;
19
+
20
+ pub type LessonID = String;
21
 
22
  pub struct InnerLessonsManager {
23
  translate_client: aws_sdk_translate::Client,
24
  polly_client: aws_sdk_polly::Client,
25
+ lessons: Arc<RwLock<BTreeMap<LessonID, Lesson>>>,
26
  }
27
 
28
  #[derive(Clone)]
 
44
  }
45
  }
46
 
47
+ pub enum AsrEngine {
48
  AWS,
49
  #[allow(dead_code)]
50
  #[cfg(feature = "whisper")]
 
70
  }
71
 
72
  impl LessonsManager {
73
+ pub fn new(sdk_config: &SdkConfig) -> Self {
74
  let translate_client = aws_sdk_translate::Client::new(sdk_config);
75
  let polly_client = aws_sdk_polly::Client::new(sdk_config);
76
  let inner = InnerLessonsManager {
 
81
  LessonsManager { inner: Arc::new(inner) }
82
  }
83
 
84
+ pub(crate) async fn create_lesson(&self, id: LessonID, engine: AsrEngine, speaker_lang: LanguageCode) -> Lesson {
85
  let mut map = self.lessons.write().await;
86
  let lesson: Lesson = InnerLesson::new(self.clone(), engine, speaker_lang).await.into();
87
  map.insert(id, lesson.clone());
88
  lesson
89
  }
90
 
91
+ pub(crate) async fn get_lesson(&self, id: LessonID) -> Option<Lesson> {
92
  let map = self.lessons.read().await;
93
  map.get(&id).cloned()
94
  }
 
435
  }
436
 
437
 
 
 
 
 
 
 
438
 
439
  #[derive(Debug)]
440
  enum SynthesizeError {
src/lib.rs ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #[cfg(feature = "whisper")]
2
+ extern crate whisper;
3
+
4
+ pub mod config;
5
+ pub mod lesson;
6
+ pub mod asr;
7
+ pub mod base64box;
8
+
9
+ use aws_config::{SdkConfig};
10
+ use aws_sdk_transcribestreaming::types::LanguageCode;
11
+ use futures_util::{stream::StreamExt, SinkExt};
12
+ use poem::{
13
+ endpoint::{StaticFileEndpoint, StaticFilesEndpoint},
14
+ get, handler,
15
+ listener::TcpListener,
16
+ web::{
17
+ websocket::{Message, WebSocket},
18
+ Data, Query,
19
+ },
20
+ EndpointExt, IntoResponse, Route, Server,
21
+ };
22
+ use serde::{Deserialize, Serialize};
23
+ use tokio::select;
24
+ use tracing::{debug, span};
25
+
26
+ use crate::base64box::Base64Box;
27
+ use crate::config::SETTINGS;
28
+ use crate::lesson::{AsrEngine, LessonID, LessonsManager};
29
+
30
+ #[derive(Clone)]
31
+ pub struct Context {
32
+ lessons_manager: LessonsManager,
33
+ }
34
+
35
+ impl Context {
36
+ pub fn new(config: &SdkConfig) -> Self {
37
+ Self {
38
+ lessons_manager: LessonsManager::new(config),
39
+ }
40
+ }
41
+ }
42
+
43
+ #[derive(Serialize, Deserialize, Debug)]
44
+ pub enum SingleEvent {
45
+ #[serde(rename = "original")]
46
+ Transcription {
47
+ content: String,
48
+ #[serde(rename = "isFinal")]
49
+ is_final: bool
50
+ },
51
+ #[serde(rename = "translated")]
52
+ Translation { content: String },
53
+ #[serde(rename = "voice")]
54
+ Voice {
55
+ content: Base64Box
56
+ },
57
+ }
58
+
59
+
60
+ #[derive(Deserialize, Debug)]
61
+ pub struct LessonListenerQuery {
62
+ id: String,
63
+ language: String,
64
+ voice: String,
65
+ }
66
+
67
+ #[derive(Serialize, Debug)]
68
+ #[serde(tag = "type")]
69
+ pub enum LiveLessonTextEvent {
70
+ #[serde(rename = "original")]
71
+ Transcription {
72
+ content: String,
73
+ #[serde(rename = "isFinal")]
74
+ is_final: bool
75
+ },
76
+ Translation { content: String },
77
+ LipSync { visemes: Vec<Viseme> },
78
+ }
79
+
80
+
81
+ // {"time":180,"type":"viseme","value":"r"}
82
+ #[derive(Debug, Deserialize, Clone, Serialize)]
83
+ pub struct Viseme {
84
+ time: u32,
85
+ value: String,
86
+ }
87
+
88
+ #[derive(Deserialize, Debug)]
89
+ pub struct LessonSpeakerQuery {
90
+ id: LessonID,
91
+ language: String,
92
+ #[allow(dead_code)] // TODO: use this in the future
93
+ prompt: Option<String>,
94
+ }
95
+
96
+ #[handler]
97
+ async fn stream_speaker(
98
+ ctx: Data<&Context>,
99
+ query: Query<LessonSpeakerQuery>,
100
+ ws: WebSocket,
101
+ ) -> impl IntoResponse {
102
+ let lessons_manager = ctx.lessons_manager.clone();
103
+ ws.on_upgrade(|mut socket| async move {
104
+ let Ok(lang) = query.language.parse::<LanguageCode>() else {
105
+ let _ = socket
106
+ .send(Message::Text(format!("invalid language code: {}", query.language)))
107
+ .await;
108
+ return
109
+ };
110
+ let lesson = lessons_manager
111
+ .create_lesson(
112
+ query.id.clone(),
113
+ AsrEngine::AWS,
114
+ lang,
115
+ )
116
+ .await;
117
+
118
+ let mut transcribe_rx = lesson.transcript_channel();
119
+ let fut = async {
120
+ loop {
121
+ select! {
122
+ msg = socket.next() => {
123
+ let Some(res) = msg else { break };
124
+ let msg = res?;
125
+ if msg.is_close() {
126
+ break
127
+ }
128
+ let Message::Binary(bin) = msg else {
129
+ tracing::warn!("Other: {:?}", msg);
130
+ continue
131
+ };
132
+ let frame = u8_to_i16(&bin);
133
+ lesson.send(frame).await?
134
+ },
135
+ output = transcribe_rx.recv() => {
136
+ let evt = output?;
137
+ if evt.is_final {
138
+ tracing::trace!("Transcribed: {}", evt.transcript);
139
+ }
140
+ let evt = LiveLessonTextEvent::Transcription { content: evt.transcript, is_final: evt.is_final };
141
+ let Ok(json) = serde_json::to_string(&evt) else {
142
+ tracing::warn!("failed to serialize json: {:?}", evt);
143
+ continue
144
+ };
145
+ socket.send(Message::Text(json)).await?
146
+ },
147
+ }
148
+ }
149
+ Ok(())
150
+ };
151
+
152
+ let span = span!(tracing::Level::TRACE, "lesson_speaker", lesson_id = query.id);
153
+ let _ = span.enter();
154
+ let res: anyhow::Result<()> = fut.await;
155
+ match res {
156
+ Ok(()) => {
157
+ tracing::info!("lesson speaker closed");
158
+ }
159
+ Err(e) => {
160
+ tracing::warn!("lesson speaker error: {}", e);
161
+ }
162
+ }
163
+ })
164
+ }
165
+
166
+ #[handler]
167
+ async fn stream_listener(
168
+ ctx: Data<&Context>,
169
+ query: Query<LessonListenerQuery>,
170
+ ws: WebSocket,
171
+ ) -> impl IntoResponse {
172
+ let lessons_manager = ctx.lessons_manager.clone();
173
+
174
+ ws.on_upgrade(|mut socket| async move {
175
+ let voice_id = match query.voice.parse() {
176
+ Ok(id) => id,
177
+ Err(e) => {
178
+ let _ = socket
179
+ .send(Message::Text(format!("invalid voice id: {}", e)))
180
+ .await;
181
+ return;
182
+ }
183
+ };
184
+
185
+ let lesson_opt = lessons_manager.get_lesson(query.id.clone()).await;
186
+ debug!("listener param = {:?}", query);
187
+ let Some(lesson) = lesson_opt else {
188
+ let _ = socket
189
+ .send(Message::Text("lesson not found".to_string()))
190
+ .await;
191
+ return;
192
+ };
193
+ let mut transcript_rx = lesson.transcript_channel();
194
+ let mut lang_lesson = lesson.get_or_init(query.language.clone()).await;
195
+ let mut translate_rx = lang_lesson.translated_channel();
196
+ let mut voice_lesson = lang_lesson.get_or_init(voice_id).await;
197
+ let mut voice_rx = voice_lesson.voice_channel();
198
+ let mut lip_sync_rx = voice_lesson.lip_sync_channel();
199
+
200
+ let fut = async {
201
+ loop {
202
+ select! {
203
+ transcript_poll = transcript_rx.recv() => {
204
+ let transcript = transcript_poll?;
205
+ let evt = LiveLessonTextEvent::Transcription {
206
+ content: transcript.transcript,
207
+ is_final: transcript.is_final
208
+ };
209
+ let Ok(json) = serde_json::to_string(&evt) else {
210
+ tracing::warn!("failed to serialize: {:?}", evt);
211
+ continue
212
+ };
213
+ tracing::debug!("Transcribed: {}", json);
214
+ socket.send(Message::Text(json)).await?
215
+ },
216
+ translated_poll = translate_rx.recv() => {
217
+ let translated = translated_poll?;
218
+ let evt = LiveLessonTextEvent::Translation { content: translated };
219
+ let Ok(json) = serde_json::to_string(&evt) else {
220
+ tracing::warn!("failed to serialize: {:?}", evt);
221
+ continue
222
+ };
223
+ tracing::debug!("Translated: {}", json);
224
+ socket.send(Message::Text(json)).await?
225
+ },
226
+ voice_poll = voice_rx.recv() => {
227
+ let voice = voice_poll?;
228
+ socket.send(Message::Binary(voice)).await?
229
+ },
230
+ visemes_poll = lip_sync_rx.recv() => {
231
+ let visemes = visemes_poll?;
232
+ let evt = LiveLessonTextEvent::LipSync { visemes };
233
+ let Ok(json) = serde_json::to_string(&evt) else {
234
+ tracing::warn!("failed to serialize: {:?}", evt);
235
+ continue
236
+ };
237
+ socket.send(Message::Text(json)).await?
238
+ },
239
+ }
240
+ }
241
+ };
242
+
243
+ let span = span!(tracing::Level::TRACE, "lesson_listener", lesson_id = query.id);
244
+ let _ = span.enter();
245
+ let res: anyhow::Result<()> = fut.await;
246
+ match res {
247
+ Ok(()) => {
248
+ tracing::info!("lesson listener closed");
249
+ }
250
+ Err(e) => {
251
+ tracing::warn!("lesson listener error: {}", e);
252
+ }
253
+ }
254
+ })
255
+ }
256
+
257
+ #[derive(Deserialize, Debug)]
258
+ pub struct SingleQuery {
259
+ id: LessonID,
260
+ from: String,
261
+ to: String,
262
+ voice: Option<String>,
263
+ }
264
+
265
+
266
+ #[handler]
267
+ pub async fn stream_single(
268
+ ctx: Data<&Context>,
269
+ query: Query<SingleQuery>,
270
+ ws: WebSocket
271
+ ) -> impl IntoResponse {
272
+ let lessons_manager = ctx.lessons_manager.clone();
273
+ ws.on_upgrade(|mut socket| async move {
274
+ let Ok(lang) = query.from.parse::<LanguageCode>() else {
275
+ let _ = socket
276
+ .send(Message::Text(format!("invalid language code: {}", query.from)))
277
+ .await;
278
+ return
279
+ };
280
+ let lesson = lessons_manager
281
+ .create_lesson(
282
+ query.id.clone(),
283
+ AsrEngine::AWS,
284
+ lang,
285
+ )
286
+ .await;
287
+
288
+ let mut transcribe_rx = lesson.transcript_channel();
289
+ let mut lang_lesson = lesson.get_or_init(query.to.clone()).await;
290
+ let mut translate_rx = lang_lesson.translated_channel();
291
+ let Ok(voice_id) = query.voice.as_deref().unwrap_or("Amy").parse() else {
292
+ let _ = socket
293
+ .send(Message::Text(format!("invalid voice id: {:?}", query.voice)))
294
+ .await;
295
+ return
296
+ };
297
+ let mut voice_lesson = lang_lesson.get_or_init(voice_id).await;
298
+ let mut voice_rx = voice_lesson.voice_channel();
299
+ // let mut lip_sync_rx = voice_lesson.lip_sync_channel();
300
+
301
+ let fut = async {
302
+ loop {
303
+ let evt = select! {
304
+ input = socket.next() => {
305
+ let Some(res) = input else { break };
306
+ let msg = res?;
307
+ if msg.is_close() {
308
+ break
309
+ }
310
+ let Message::Binary(bin) = msg else {
311
+ tracing::warn!("Other: {:?}", msg);
312
+ continue
313
+ };
314
+ let frame = u8_to_i16(&bin);
315
+ lesson.send(frame).await?;
316
+ continue
317
+ },
318
+ transcript_poll = transcribe_rx.recv() => {
319
+ let evt = transcript_poll?;
320
+ if evt.is_final {
321
+ tracing::trace!("Transcribed: {}", evt.transcript);
322
+ }
323
+ SingleEvent::Transcription { content: evt.transcript, is_final: evt.is_final }
324
+ },
325
+ translated_poll = translate_rx.recv() => {
326
+ let translated = translated_poll?;
327
+ SingleEvent::Translation { content: translated }
328
+ },
329
+ voice_poll = voice_rx.recv() => {
330
+ let voice = voice_poll?;
331
+ SingleEvent::Voice { content: Base64Box(voice) }
332
+ },
333
+ };
334
+
335
+ let Ok(json) = serde_json::to_string(&evt) else {
336
+ tracing::warn!("failed to serialize json: {:?}", evt);
337
+ continue
338
+ };
339
+ socket.send(Message::Text(json)).await?
340
+ }
341
+ Ok(())
342
+ };
343
+
344
+ let span = span!(tracing::Level::TRACE, "lesson_speaker", lesson_id = query.id);
345
+ let _ = span.enter();
346
+ let res: anyhow::Result<()> = fut.await;
347
+ match res {
348
+ Ok(()) => {
349
+ tracing::info!("lesson speaker closed");
350
+ }
351
+ Err(e) => {
352
+ tracing::warn!("lesson speaker error: {}", e);
353
+ }
354
+ }
355
+ })
356
+ }
357
+
358
+ fn u8_to_i16(input: &[u8]) -> Vec<i16> {
359
+ input
360
+ .chunks_exact(2)
361
+ .map(|chunk| {
362
+ let mut buf = [0u8; 2];
363
+ buf.copy_from_slice(chunk);
364
+ i16::from_le_bytes(buf)
365
+ })
366
+ .collect::<Vec<i16>>()
367
+ }
368
+
369
+ pub async fn app(config: &SdkConfig) -> Result<(), std::io::Error> {
370
+ let ctx = Context {
371
+ lessons_manager: LessonsManager::new(config),
372
+ };
373
+
374
+ let app = Route::new()
375
+ .nest(
376
+ "/",
377
+ StaticFilesEndpoint::new("./static")
378
+ .show_files_listing()
379
+ .index_file("index.html"),
380
+ )
381
+ .at("/ws/lesson-speaker", get(stream_speaker))
382
+ .at("/ws/teacher", get(stream_speaker))
383
+ .at("/ws/lesson-listener", get(stream_listener))
384
+ .at("/ws/student", get(stream_listener))
385
+ .at("/ws/voice", get(stream_single))
386
+ .at(
387
+ "lesson-speaker",
388
+ StaticFileEndpoint::new("./static/index.html"),
389
+ )
390
+ .at(
391
+ "lesson-listener",
392
+ StaticFileEndpoint::new("./static/index.html"),
393
+ )
394
+ .data(ctx);
395
+ let addr = format!("{}:{}", SETTINGS.server.host, SETTINGS.server.port);
396
+ let listener = TcpListener::bind(addr);
397
+ let server = Server::new(listener);
398
+
399
+ server.run(app).await
400
+ }
src/main.rs CHANGED
@@ -3,42 +3,14 @@
3
  * SPDX-License-Identifier: Apache-2.0.
4
  */
5
 
6
- #![allow(clippy::result_large_err)]
7
-
8
- #[cfg(feature = "whisper")]
9
- extern crate whisper;
10
-
11
  use aws_config::BehaviorVersion;
12
  use aws_sdk_transcribestreaming::meta::PKG_VERSION;
13
- use aws_sdk_transcribestreaming::types::LanguageCode;
14
- use futures_util::{stream::StreamExt, SinkExt};
15
- use poem::{
16
- endpoint::{StaticFileEndpoint, StaticFilesEndpoint},
17
- get, handler,
18
- listener::TcpListener,
19
- web::{
20
- websocket::{Message, WebSocket},
21
- Data, Query,
22
- },
23
- EndpointExt, IntoResponse, Route, Server,
24
- };
25
- use serde::{Deserialize, Serialize};
26
  use tokio::select;
27
- use tracing::{debug, span};
28
  use tracing_subscriber::{fmt, prelude::*, EnvFilter};
29
 
30
- use crate::{config::*, lesson::*};
31
- use crate::base64box::Base64Box;
32
 
33
- mod config;
34
- mod lesson;
35
- mod asr;
36
- mod base64box;
37
-
38
- #[derive(Clone)]
39
- struct Context {
40
- lessons_manager: LessonsManager,
41
- }
42
 
43
  #[tokio::main]
44
  async fn main() -> Result<(), std::io::Error> {
@@ -50,476 +22,13 @@ async fn main() -> Result<(), std::io::Error> {
50
  debug!("Transcribe client version: {}", PKG_VERSION);
51
 
52
  let shared_config = aws_config::load_defaults(BehaviorVersion::latest()).await;
53
- let ctx = Context {
54
- lessons_manager: LessonsManager::new(&shared_config),
55
- };
56
 
57
- let app = Route::new()
58
- .nest(
59
- "/",
60
- StaticFilesEndpoint::new("./static")
61
- .show_files_listing()
62
- .index_file("index.html"),
63
- )
64
- .at("/ws/lesson-speaker", get(stream_speaker))
65
- .at("/ws/teacher", get(stream_speaker))
66
- .at("/ws/lesson-listener", get(stream_listener))
67
- .at("/ws/student", get(stream_listener))
68
- .at("/ws/voice", get(stream_single))
69
- .at(
70
- "lesson-speaker",
71
- StaticFileEndpoint::new("./static/index.html"),
72
- )
73
- .at(
74
- "lesson-listener",
75
- StaticFileEndpoint::new("./static/index.html"),
76
- )
77
- .data(ctx);
78
- let addr = format!("{}:{}", SETTINGS.server.host, SETTINGS.server.port);
79
- let listener = TcpListener::bind(addr);
80
- let server = Server::new(listener);
81
 
82
  select! {
83
- res = server.run(app) => res,
84
  _ = tokio::signal::ctrl_c() => {
85
  tracing::info!("Shutting down");
86
  Ok(())
87
  },
88
  }
89
  }
90
-
91
- #[derive(Deserialize, Debug)]
92
- pub struct LessonSpeakerQuery {
93
- id: u32,
94
- language: String,
95
- #[allow(dead_code)] // TODO: use this in the future
96
- prompt: Option<String>,
97
- }
98
-
99
- #[handler]
100
- async fn stream_speaker(
101
- ctx: Data<&Context>,
102
- query: Query<LessonSpeakerQuery>,
103
- ws: WebSocket,
104
- ) -> impl IntoResponse {
105
- let lessons_manager = ctx.lessons_manager.clone();
106
- ws.on_upgrade(|mut socket| async move {
107
- let Ok(lang) = query.language.parse::<LanguageCode>() else {
108
- let _ = socket
109
- .send(Message::Text(format!("invalid language code: {}", query.language)))
110
- .await;
111
- return
112
- };
113
- let lesson = lessons_manager
114
- .create_lesson(
115
- query.id,
116
- AsrEngine::AWS,
117
- lang,
118
- )
119
- .await;
120
-
121
- let mut transcribe_rx = lesson.transcript_channel();
122
- let fut = async {
123
- loop {
124
- select! {
125
- msg = socket.next() => {
126
- let Some(res) = msg else { break };
127
- let msg = res?;
128
- if msg.is_close() {
129
- break
130
- }
131
- let Message::Binary(bin) = msg else {
132
- tracing::warn!("Other: {:?}", msg);
133
- continue
134
- };
135
- let frame = u8_to_i16(&bin);
136
- lesson.send(frame).await?
137
- },
138
- output = transcribe_rx.recv() => {
139
- let evt = output?;
140
- if evt.is_final {
141
- tracing::trace!("Transcribed: {}", evt.transcript);
142
- }
143
- let evt = LiveLessonTextEvent::Transcription { content: evt.transcript, is_final: evt.is_final };
144
- let Ok(json) = serde_json::to_string(&evt) else {
145
- tracing::warn!("failed to serialize json: {:?}", evt);
146
- continue
147
- };
148
- socket.send(Message::Text(json)).await?
149
- },
150
- }
151
- }
152
- Ok(())
153
- };
154
-
155
- let span = span!(tracing::Level::TRACE, "lesson_speaker", lesson_id = query.id);
156
- let _ = span.enter();
157
- let res: anyhow::Result<()> = fut.await;
158
- match res {
159
- Ok(()) => {
160
- tracing::info!("lesson speaker closed");
161
- }
162
- Err(e) => {
163
- tracing::warn!("lesson speaker error: {}", e);
164
- }
165
- }
166
- })
167
- }
168
-
169
- #[derive(Deserialize, Debug)]
170
- pub struct LessonListenerQuery {
171
- id: u32,
172
- language: String,
173
- voice: String,
174
- }
175
-
176
- #[derive(Serialize, Debug)]
177
- #[serde(tag = "type")]
178
- enum LiveLessonTextEvent {
179
- #[serde(rename = "original")]
180
- Transcription {
181
- content: String,
182
- #[serde(rename = "isFinal")]
183
- is_final: bool
184
- },
185
- Translation { content: String },
186
- LipSync { visemes: Vec<Viseme> },
187
- }
188
- #[handler]
189
- async fn stream_listener(
190
- ctx: Data<&Context>,
191
- query: Query<LessonListenerQuery>,
192
- ws: WebSocket,
193
- ) -> impl IntoResponse {
194
- let lessons_manager = ctx.lessons_manager.clone();
195
-
196
- ws.on_upgrade(|mut socket| async move {
197
- let voice_id = match query.voice.parse() {
198
- Ok(id) => id,
199
- Err(e) => {
200
- let _ = socket
201
- .send(Message::Text(format!("invalid voice id: {}", e)))
202
- .await;
203
- return;
204
- }
205
- };
206
-
207
- let lesson_opt = lessons_manager.get_lesson(query.id).await;
208
- debug!("listener param = {:?}", query);
209
- let Some(lesson) = lesson_opt else {
210
- let _ = socket
211
- .send(Message::Text("lesson not found".to_string()))
212
- .await;
213
- return;
214
- };
215
- let mut transcript_rx = lesson.transcript_channel();
216
- let mut lang_lesson = lesson.get_or_init(query.language.clone()).await;
217
- let mut translate_rx = lang_lesson.translated_channel();
218
- let mut voice_lesson = lang_lesson.get_or_init(voice_id).await;
219
- let mut voice_rx = voice_lesson.voice_channel();
220
- let mut lip_sync_rx = voice_lesson.lip_sync_channel();
221
-
222
- let fut = async {
223
- loop {
224
- select! {
225
- transcript_poll = transcript_rx.recv() => {
226
- let transcript = transcript_poll?;
227
- let evt = LiveLessonTextEvent::Transcription {
228
- content: transcript.transcript,
229
- is_final: transcript.is_final
230
- };
231
- let Ok(json) = serde_json::to_string(&evt) else {
232
- tracing::warn!("failed to serialize: {:?}", evt);
233
- continue
234
- };
235
- tracing::debug!("Transcribed: {}", json);
236
- socket.send(Message::Text(json)).await?
237
- },
238
- translated_poll = translate_rx.recv() => {
239
- let translated = translated_poll?;
240
- let evt = LiveLessonTextEvent::Translation { content: translated };
241
- let Ok(json) = serde_json::to_string(&evt) else {
242
- tracing::warn!("failed to serialize: {:?}", evt);
243
- continue
244
- };
245
- tracing::debug!("Translated: {}", json);
246
- socket.send(Message::Text(json)).await?
247
- },
248
- voice_poll = voice_rx.recv() => {
249
- let voice = voice_poll?;
250
- socket.send(Message::Binary(voice)).await?
251
- },
252
- visemes_poll = lip_sync_rx.recv() => {
253
- let visemes = visemes_poll?;
254
- let evt = LiveLessonTextEvent::LipSync { visemes };
255
- let Ok(json) = serde_json::to_string(&evt) else {
256
- tracing::warn!("failed to serialize: {:?}", evt);
257
- continue
258
- };
259
- socket.send(Message::Text(json)).await?
260
- },
261
- }
262
- }
263
- };
264
-
265
- let span = span!(tracing::Level::TRACE, "lesson_listener", lesson_id = query.id);
266
- let _ = span.enter();
267
- let res: anyhow::Result<()> = fut.await;
268
- match res {
269
- Ok(()) => {
270
- tracing::info!("lesson listener closed");
271
- }
272
- Err(e) => {
273
- tracing::warn!("lesson listener error: {}", e);
274
- }
275
- }
276
- })
277
- }
278
-
279
- #[derive(Serialize, Deserialize, Debug)]
280
- enum SingleEvent {
281
- #[serde(rename = "original")]
282
- Transcription {
283
- content: String,
284
- #[serde(rename = "isFinal")]
285
- is_final: bool
286
- },
287
- #[serde(rename = "translated")]
288
- Translation { content: String },
289
- #[serde(rename = "voice")]
290
- Voice {
291
- content: Base64Box
292
- },
293
- }
294
-
295
-
296
- #[derive(Deserialize, Debug)]
297
- pub struct SingleQuery {
298
- id: u32,
299
- from: String,
300
- to: String,
301
- voice: Option<String>,
302
- }
303
-
304
-
305
- #[handler]
306
- async fn stream_single(
307
- ctx: Data<&Context>,
308
- query: Query<SingleQuery>,
309
- ws: WebSocket
310
- ) -> impl IntoResponse {
311
- let lessons_manager = ctx.lessons_manager.clone();
312
- ws.on_upgrade(|mut socket| async move {
313
- let Ok(lang) = query.from.parse::<LanguageCode>() else {
314
- let _ = socket
315
- .send(Message::Text(format!("invalid language code: {}", query.from)))
316
- .await;
317
- return
318
- };
319
- let lesson = lessons_manager
320
- .create_lesson(
321
- query.id,
322
- AsrEngine::AWS,
323
- lang,
324
- )
325
- .await;
326
-
327
- let mut transcribe_rx = lesson.transcript_channel();
328
- let mut lang_lesson = lesson.get_or_init(query.to.clone()).await;
329
- let mut translate_rx = lang_lesson.translated_channel();
330
- let Ok(voice_id) = query.voice.as_deref().unwrap_or("Amy").parse() else {
331
- let _ = socket
332
- .send(Message::Text(format!("invalid voice id: {:?}", query.voice)))
333
- .await;
334
- return
335
- };
336
- let mut voice_lesson = lang_lesson.get_or_init(voice_id).await;
337
- let mut voice_rx = voice_lesson.voice_channel();
338
- // let mut lip_sync_rx = voice_lesson.lip_sync_channel();
339
-
340
- let fut = async {
341
- loop {
342
- let evt = select! {
343
- input = socket.next() => {
344
- let Some(res) = input else { break };
345
- let msg = res?;
346
- if msg.is_close() {
347
- break
348
- }
349
- let Message::Binary(bin) = msg else {
350
- tracing::warn!("Other: {:?}", msg);
351
- continue
352
- };
353
- let frame = u8_to_i16(&bin);
354
- lesson.send(frame).await?;
355
- continue
356
- },
357
- transcript_poll = transcribe_rx.recv() => {
358
- let evt = transcript_poll?;
359
- if evt.is_final {
360
- tracing::trace!("Transcribed: {}", evt.transcript);
361
- }
362
- SingleEvent::Transcription { content: evt.transcript, is_final: evt.is_final }
363
- },
364
- translated_poll = translate_rx.recv() => {
365
- let translated = translated_poll?;
366
- SingleEvent::Translation { content: translated }
367
- },
368
- voice_poll = voice_rx.recv() => {
369
- let voice = voice_poll?;
370
- SingleEvent::Voice { content: Base64Box(voice) }
371
- },
372
- };
373
-
374
- let Ok(json) = serde_json::to_string(&evt) else {
375
- tracing::warn!("failed to serialize json: {:?}", evt);
376
- continue
377
- };
378
- socket.send(Message::Text(json)).await?
379
- }
380
- Ok(())
381
- };
382
-
383
- let span = span!(tracing::Level::TRACE, "lesson_speaker", lesson_id = query.id);
384
- let _ = span.enter();
385
- let res: anyhow::Result<()> = fut.await;
386
- match res {
387
- Ok(()) => {
388
- tracing::info!("lesson speaker closed");
389
- }
390
- Err(e) => {
391
- tracing::warn!("lesson speaker error: {}", e);
392
- }
393
- }
394
- })
395
- }
396
-
397
- fn u8_to_i16(input: &[u8]) -> Vec<i16> {
398
- input
399
- .chunks_exact(2)
400
- .map(|chunk| {
401
- let mut buf = [0u8; 2];
402
- buf.copy_from_slice(chunk);
403
- i16::from_le_bytes(buf)
404
- })
405
- .collect::<Vec<i16>>()
406
- }
407
-
408
- #[cfg(test)]
409
- mod test {
410
- use std::time::Duration;
411
- use async_stream::stream;
412
- use poem::listener::{Acceptor, Listener};
413
- use tokio::pin;
414
- use tokio::time::sleep;
415
- use tokio_stream::StreamExt;
416
- use tokio_tungstenite::{
417
- connect_async,
418
- tungstenite::Message,
419
- };
420
- use tracing::{info, error};
421
- use crate::asr::slice_i16_to_u8;
422
- use super::*;
423
-
424
- #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
425
- #[tracing_test::traced_test]
426
- async fn test_single() {
427
- let shared_config = aws_config::load_defaults(BehaviorVersion::latest()).await;
428
- let ctx = Context {
429
- lessons_manager: LessonsManager::new(&shared_config),
430
- };
431
-
432
- let query = SingleQuery {
433
- id: 1,
434
- from: "zh-CN".to_string(),
435
- to: "en-US".to_string(),
436
- voice: None,
437
- };
438
-
439
- let acceptor = TcpListener::bind("[::]:0")
440
- .into_acceptor()
441
- .await
442
- .unwrap();
443
- let addr = acceptor
444
- .local_addr()
445
- .remove(0)
446
- .as_socket_addr()
447
- .cloned()
448
- .unwrap();
449
- let server = Server::new_with_acceptor(acceptor);
450
- let handle = tokio::spawn(async move {
451
- let _ = server.run(
452
- Route::new()
453
- .at("/ws/voice", get(stream_single))
454
- .data(ctx)
455
- ).await;
456
- });
457
-
458
- let url = format!(
459
- "ws://{}/ws/voice?id={}&from={}&to={}&voice={}",
460
- addr, query.id, query.from, query.to, query.voice.unwrap_or("Amy".to_string())
461
- );
462
- let (mut client_stream, _) = connect_async(url)
463
- .await
464
- .unwrap();
465
-
466
- client_stream
467
- .send(Message::Binary(Vec::new()))
468
- .await
469
- .unwrap();
470
-
471
-
472
- let wav = hound::WavReader::open("whisper/samples/samples_jfk.wav")
473
- .expect("failed to open wav");
474
- let spec = wav.spec();
475
- info!("{:?}", spec);
476
- let samples = wav
477
- .into_samples::<i16>()
478
- .map(|s| s.unwrap())
479
- .collect::<Vec<i16>>();
480
- let chunks = samples.chunks(1600)
481
- .map(|chunk| chunk.to_vec())
482
- .into_iter();
483
-
484
- let audio_stream = stream! {
485
- for chunk in chunks {
486
- yield slice_i16_to_u8(&chunk);
487
- sleep(Duration::from_millis(10)).await;
488
- }
489
- };
490
- pin!(audio_stream);
491
-
492
- let recv_fut = async {
493
- while let Some(voice_slice) = audio_stream.next().await {
494
- client_stream.send(Message::Binary(voice_slice)).await?;
495
- }
496
- info!("sent all voice chunks");
497
-
498
- while let Some(next_msg) = client_stream.next().await {
499
- debug!(?next_msg);
500
- let msg = next_msg?;
501
- let Message::Text(json_str) = msg else { continue };
502
- let Ok(evt) = serde_json::from_str::<SingleEvent>(&json_str) else { continue };
503
- if let SingleEvent::Voice { .. } = evt {
504
- return Ok(())
505
- }
506
- }
507
-
508
- Ok(()) as anyhow::Result<()>
509
- };
510
-
511
- select! {
512
- res = recv_fut => {
513
- if let Err(e) = res {
514
- error!("Error: {:?}", e);
515
- assert!(false, "Error: {}", e);
516
- }
517
- }
518
- _ = sleep(Duration::from_secs(10)) => {
519
- assert!(false, "timeout");
520
- }
521
- };
522
-
523
- handle.abort();
524
- }
525
- }
 
3
  * SPDX-License-Identifier: Apache-2.0.
4
  */
5
 
 
 
 
 
 
6
  use aws_config::BehaviorVersion;
7
  use aws_sdk_transcribestreaming::meta::PKG_VERSION;
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  use tokio::select;
9
+ use tracing::{debug};
10
  use tracing_subscriber::{fmt, prelude::*, EnvFilter};
11
 
12
+ use polyhedron::app;
 
13
 
 
 
 
 
 
 
 
 
 
14
 
15
  #[tokio::main]
16
  async fn main() -> Result<(), std::io::Error> {
 
22
  debug!("Transcribe client version: {}", PKG_VERSION);
23
 
24
  let shared_config = aws_config::load_defaults(BehaviorVersion::latest()).await;
 
 
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  select! {
28
+ res = app(&shared_config) => res,
29
  _ = tokio::signal::ctrl_c() => {
30
  tracing::info!("Shutting down");
31
  Ok(())
32
  },
33
  }
34
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/stream_single.rs ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ use aws_config::BehaviorVersion;
3
+ use std::time::Duration;
4
+ use async_stream::stream;
5
+ use poem::{
6
+ get,
7
+ listener::{Listener, Acceptor, TcpListener},
8
+ EndpointExt, Route, Server,
9
+ };
10
+ use tokio::{pin, select};
11
+ use tokio::time::sleep;
12
+ use tokio_stream::StreamExt;
13
+ use tokio_tungstenite::{
14
+ connect_async,
15
+ tungstenite::Message,
16
+ };
17
+ use futures_util::sink::SinkExt;
18
+ use tracing::{info, error, debug};
19
+ use polyhedron::{
20
+ stream_single,
21
+ SingleEvent,
22
+ asr::slice_i16_to_u8,
23
+ Context
24
+ };
25
+
26
+ #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
27
+ #[tracing_test::traced_test]
28
+ async fn test_single() {
29
+ let shared_config = aws_config::load_defaults(BehaviorVersion::latest()).await;
30
+ let ctx = Context::new(&shared_config);
31
+
32
+ let acceptor = TcpListener::bind("[::]:0")
33
+ .into_acceptor()
34
+ .await
35
+ .unwrap();
36
+ let addr = acceptor
37
+ .local_addr()
38
+ .remove(0)
39
+ .as_socket_addr()
40
+ .cloned()
41
+ .unwrap();
42
+ let server = Server::new_with_acceptor(acceptor);
43
+ let handle = tokio::spawn(async move {
44
+ let _ = server.run(
45
+ Route::new()
46
+ .at("/ws/voice", get(stream_single))
47
+ .data(ctx)
48
+ ).await;
49
+ });
50
+
51
+ let url = format!(
52
+ "ws://{}/ws/voice?id=123abc&from=zh-CN&to=en-US&voice=Amy", addr
53
+ );
54
+ let (mut client_stream, _) = connect_async(url)
55
+ .await
56
+ .unwrap();
57
+
58
+ client_stream
59
+ .send(Message::Binary(Vec::new()))
60
+ .await
61
+ .unwrap();
62
+
63
+
64
+ let wav = hound::WavReader::open("whisper/samples/samples_jfk.wav")
65
+ .expect("failed to open wav");
66
+ let spec = wav.spec();
67
+ info!("{:?}", spec);
68
+ let samples = wav
69
+ .into_samples::<i16>()
70
+ .map(|s| s.unwrap())
71
+ .collect::<Vec<i16>>();
72
+ let chunks = samples.chunks(1600)
73
+ .map(|chunk| chunk.to_vec())
74
+ .into_iter();
75
+
76
+ let audio_stream = stream! {
77
+ for chunk in chunks {
78
+ yield slice_i16_to_u8(&chunk);
79
+ sleep(Duration::from_millis(10)).await;
80
+ }
81
+ };
82
+ pin!(audio_stream);
83
+
84
+ let recv_fut = async {
85
+ while let Some(voice_slice) = audio_stream.next().await {
86
+ client_stream.send(Message::Binary(voice_slice)).await?;
87
+ }
88
+ info!("sent all voice chunks");
89
+
90
+ while let Some(Ok(msg)) = client_stream.next().await {
91
+ debug!("recv: {:?}", msg);
92
+ let Message::Text(json_str) = msg else { continue };
93
+ let Ok(evt) = serde_json::from_str::<SingleEvent>(&json_str) else { continue };
94
+ if let SingleEvent::Voice { .. } = evt {
95
+ return Ok(())
96
+ }
97
+ }
98
+
99
+ Ok(()) as anyhow::Result<()>
100
+ };
101
+
102
+ select! {
103
+ res = recv_fut => {
104
+ if let Err(e) = res {
105
+ error!("Error: {:?}", e);
106
+ assert!(false, "Error: {}", e);
107
+ }
108
+ }
109
+ _ = sleep(Duration::from_secs(10)) => {
110
+ assert!(false, "timeout");
111
+ }
112
+ };
113
+
114
+ handle.abort();
115
+ }
whisper/src/handler.rs CHANGED
@@ -445,7 +445,6 @@ mod test {
445
 
446
  let mut rx = whisper_handler.subscribe();
447
  let send_fut = async {
448
- // tokio::time::sleep(Duration::from_secs(5)).await;
449
  for chunk in chunks {
450
  let _ = whisper_handler.send_i16(chunk).await.expect("failed to send sample");
451
  tokio::time::sleep(Duration::from_millis(100)).await;
 
445
 
446
  let mut rx = whisper_handler.subscribe();
447
  let send_fut = async {
 
448
  for chunk in chunks {
449
  let _ = whisper_handler.send_i16(chunk).await.expect("failed to send sample");
450
  tokio::time::sleep(Duration::from_millis(100)).await;