mingyang91 committed on
Commit
786313f
·
verified ·
1 Parent(s): f33df9c

add single channel

Browse files
Files changed (4) hide show
  1. Cargo.lock +1 -0
  2. Cargo.toml +1 -0
  3. src/base64box.rs +36 -0
  4. src/main.rs +121 -0
Cargo.lock CHANGED
@@ -1546,6 +1546,7 @@ dependencies = [
1546
  "aws-sdk-polly",
1547
  "aws-sdk-transcribestreaming",
1548
  "aws-sdk-translate",
 
1549
  "config",
1550
  "futures-util",
1551
  "lazy_static",
 
1546
  "aws-sdk-polly",
1547
  "aws-sdk-transcribestreaming",
1548
  "aws-sdk-translate",
1549
+ "base64 0.21.5",
1550
  "config",
1551
  "futures-util",
1552
  "lazy_static",
Cargo.toml CHANGED
@@ -22,6 +22,7 @@ tracing = { version = "0.1", features = [] }
22
  tracing-subscriber = { version = "0.3", features = ["env-filter"] }
23
  async-trait = "0.1.74"
24
  lazy_static = "1.4.0"
 
25
 
26
  [features]
27
  whisper = ["dep:whisper"]
 
22
  tracing-subscriber = { version = "0.3", features = ["env-filter"] }
23
  async-trait = "0.1.74"
24
  lazy_static = "1.4.0"
25
+ base64 = { version = "0.21.5", features = [] }
26
 
27
  [features]
28
  whisper = ["dep:whisper"]
src/base64box.rs ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use serde::{Deserialize, Deserializer, Serialize, Serializer};
2
+ use serde::de::Error;
3
+ use base64::{Engine as _, alphabet, engine::{GeneralPurpose, general_purpose}};
4
+ use lazy_static::lazy_static;
5
+
6
+ lazy_static! {
7
+ static ref ENGINE: GeneralPurpose = GeneralPurpose::new(
8
+ &alphabet::STANDARD,
9
+ general_purpose::NO_PAD
10
+ );
11
+ }
12
/// Newtype over raw bytes whose serde representation is a base64 string
/// produced and consumed through `ENGINE`.
#[derive(Debug)]
pub struct Base64Box(pub Vec<u8>);
14
+ impl Serialize for Base64Box {
15
+ fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
16
+ serializer.collect_str(&ENGINE.encode(&self.0))
17
+ }
18
+ }
19
+
20
+ impl<'de> Deserialize<'de> for Base64Box {
21
+ fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
22
+ struct Vis;
23
+ impl serde::de::Visitor<'_> for Vis {
24
+ type Value = Base64Box;
25
+
26
+ fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
27
+ formatter.write_str("a base64 string")
28
+ }
29
+
30
+ fn visit_str<E: Error>(self, v: &str) -> Result<Self::Value, E> {
31
+ ENGINE.decode(v).map(Base64Box).map_err(Error::custom)
32
+ }
33
+ }
34
+ deserializer.deserialize_str(Vis)
35
+ }
36
+ }
src/main.rs CHANGED
@@ -28,10 +28,12 @@ use tracing::{debug, span};
28
  use tracing_subscriber::{fmt, prelude::*, EnvFilter};
29
 
30
  use crate::{config::*, lesson::*};
 
31
 
32
  mod config;
33
  mod lesson;
34
  mod asr;
 
35
 
36
  #[derive(Clone)]
37
  struct Context {
@@ -63,6 +65,7 @@ async fn main() -> Result<(), std::io::Error> {
63
  .at("/ws/teacher", get(stream_speaker))
64
  .at("/ws/lesson-listener", get(stream_listener))
65
  .at("/ws/student", get(stream_listener))
 
66
  .at(
67
  "lesson-speaker",
68
  StaticFileEndpoint::new("./static/index.html"),
@@ -273,6 +276,124 @@ async fn stream_listener(
273
  })
274
  }
275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  fn u8_to_i16(input: &[u8]) -> Vec<i16> {
277
  input
278
  .chunks_exact(2)
 
28
  use tracing_subscriber::{fmt, prelude::*, EnvFilter};
29
 
30
  use crate::{config::*, lesson::*};
31
+ use crate::base64box::Base64Box;
32
 
33
  mod config;
34
  mod lesson;
35
  mod asr;
36
+ mod base64box;
37
 
38
  #[derive(Clone)]
39
  struct Context {
 
65
  .at("/ws/teacher", get(stream_speaker))
66
  .at("/ws/lesson-listener", get(stream_listener))
67
  .at("/ws/student", get(stream_listener))
68
+ .at("/ws/voice", get(stream_single))
69
  .at(
70
  "lesson-speaker",
71
  StaticFileEndpoint::new("./static/index.html"),
 
276
  })
277
  }
278
 
279
+ #[derive(Serialize, Debug)]
280
+ enum SingleEvent {
281
+ #[serde(rename = "original")]
282
+ Transcription {
283
+ content: String,
284
+ #[serde(rename = "isFinal")]
285
+ is_final: bool
286
+ },
287
+ #[serde(rename = "translated")]
288
+ Translation { content: String },
289
+ #[serde(rename = "voice")]
290
+ Voice {
291
+ content: Base64Box
292
+ },
293
+ }
294
+
295
+
296
+ #[derive(Deserialize, Debug)]
297
+ pub struct SingleQuery {
298
+ id: u32,
299
+ from: String,
300
+ to: String,
301
+ voice: Option<String>,
302
+ }
303
+
304
+
305
+ #[handler]
306
+ async fn stream_single(
307
+ ctx: Data<&Context>,
308
+ query: Query<SingleQuery>,
309
+ ws: WebSocket
310
+ ) -> impl IntoResponse {
311
+ let lessons_manager = ctx.lessons_manager.clone();
312
+ ws.on_upgrade(|mut socket| async move {
313
+ let Ok(lang) = query.from.parse::<LanguageCode>() else {
314
+ let _ = socket
315
+ .send(Message::Text(format!("invalid language code: {}", query.from)))
316
+ .await;
317
+ return
318
+ };
319
+ let lesson = lessons_manager
320
+ .create_lesson(
321
+ query.id,
322
+ AsrEngine::AWS,
323
+ lang,
324
+ )
325
+ .await;
326
+
327
+ let mut transcribe_rx = lesson.transcript_channel();
328
+ let mut lang_lesson = lesson.get_or_init(query.to.clone()).await;
329
+ let mut translate_rx = lang_lesson.translated_channel();
330
+ let Ok(voice_id) = query.voice.as_deref().unwrap_or("Amy").parse() else {
331
+ let _ = socket
332
+ .send(Message::Text(format!("invalid voice id: {:?}", query.voice)))
333
+ .await;
334
+ return
335
+ };
336
+ let mut voice_lesson = lang_lesson.get_or_init(voice_id).await;
337
+ let mut voice_rx = voice_lesson.voice_channel();
338
+ // let mut lip_sync_rx = voice_lesson.lip_sync_channel();
339
+
340
+ let fut = async {
341
+ loop {
342
+ let evt = select! {
343
+ input = socket.next() => {
344
+ let Some(res) = input else { break };
345
+ let msg = res?;
346
+ if msg.is_close() {
347
+ break
348
+ }
349
+ let Message::Binary(bin) = msg else {
350
+ tracing::warn!("Other: {:?}", msg);
351
+ continue
352
+ };
353
+ let frame = u8_to_i16(&bin);
354
+ lesson.send(frame).await?;
355
+ continue
356
+ },
357
+ transcript_poll = transcribe_rx.recv() => {
358
+ let evt = transcript_poll?;
359
+ if evt.is_final {
360
+ tracing::trace!("Transcribed: {}", evt.transcript);
361
+ }
362
+ SingleEvent::Transcription { content: evt.transcript, is_final: evt.is_final }
363
+ },
364
+ translated_poll = translate_rx.recv() => {
365
+ let translated = translated_poll?;
366
+ SingleEvent::Translation { content: translated }
367
+ },
368
+ voice_poll = voice_rx.recv() => {
369
+ let voice = voice_poll?;
370
+ SingleEvent::Voice { content: Base64Box(voice) }
371
+ },
372
+ };
373
+
374
+ let Ok(json) = serde_json::to_string(&evt) else {
375
+ tracing::warn!("failed to serialize json: {:?}", evt);
376
+ continue
377
+ };
378
+ socket.send(Message::Text(json)).await?
379
+ }
380
+ Ok(())
381
+ };
382
+
383
+ let span = span!(tracing::Level::TRACE, "lesson_speaker", lesson_id = query.id);
384
+ let _ = span.enter();
385
+ let res: anyhow::Result<()> = fut.await;
386
+ match res {
387
+ Ok(()) => {
388
+ tracing::info!("lesson speaker closed");
389
+ }
390
+ Err(e) => {
391
+ tracing::warn!("lesson speaker error: {}", e);
392
+ }
393
+ }
394
+ })
395
+ }
396
+
397
  fn u8_to_i16(input: &[u8]) -> Vec<i16> {
398
  input
399
  .chunks_exact(2)