Spaces:
Sleeping
Sleeping
re-org code structure
Browse files- Cargo.lock +12 -12
- Cargo.toml +8 -0
- src/asr/aws.rs +2 -0
- src/asr/mod.rs +1 -2
- src/lesson.rs +8 -13
- src/lib.rs +400 -0
- src/main.rs +3 -494
- tests/stream_single.rs +115 -0
- whisper/src/handler.rs +0 -1
Cargo.lock
CHANGED
@@ -696,9 +696,9 @@ dependencies = [
|
|
696 |
|
697 |
[[package]]
|
698 |
name = "config"
|
699 |
-
version = "0.13.
|
700 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
701 |
-
checksum = "
|
702 |
dependencies = [
|
703 |
"async-trait",
|
704 |
"json5",
|
@@ -838,9 +838,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
|
|
838 |
|
839 |
[[package]]
|
840 |
name = "form_urlencoded"
|
841 |
-
version = "1.2.
|
842 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
843 |
-
checksum = "
|
844 |
dependencies = [
|
845 |
"percent-encoding",
|
846 |
]
|
@@ -853,9 +853,9 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
|
|
853 |
|
854 |
[[package]]
|
855 |
name = "futures"
|
856 |
-
version = "0.3.
|
857 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
858 |
-
checksum = "
|
859 |
dependencies = [
|
860 |
"futures-channel",
|
861 |
"futures-core",
|
@@ -1179,9 +1179,9 @@ dependencies = [
|
|
1179 |
|
1180 |
[[package]]
|
1181 |
name = "idna"
|
1182 |
-
version = "0.
|
1183 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
1184 |
-
checksum = "
|
1185 |
dependencies = [
|
1186 |
"unicode-bidi",
|
1187 |
"unicode-normalization",
|
@@ -1487,9 +1487,9 @@ checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
|
|
1487 |
|
1488 |
[[package]]
|
1489 |
name = "percent-encoding"
|
1490 |
-
version = "2.3.
|
1491 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
1492 |
-
checksum = "
|
1493 |
|
1494 |
[[package]]
|
1495 |
name = "pest"
|
@@ -2535,9 +2535,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
|
|
2535 |
|
2536 |
[[package]]
|
2537 |
name = "url"
|
2538 |
-
version = "2.
|
2539 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
2540 |
-
checksum = "
|
2541 |
dependencies = [
|
2542 |
"form_urlencoded",
|
2543 |
"idna",
|
|
|
696 |
|
697 |
[[package]]
|
698 |
name = "config"
|
699 |
+
version = "0.13.4"
|
700 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
701 |
+
checksum = "23738e11972c7643e4ec947840fc463b6a571afcd3e735bdfce7d03c7a784aca"
|
702 |
dependencies = [
|
703 |
"async-trait",
|
704 |
"json5",
|
|
|
838 |
|
839 |
[[package]]
|
840 |
name = "form_urlencoded"
|
841 |
+
version = "1.2.1"
|
842 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
843 |
+
checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456"
|
844 |
dependencies = [
|
845 |
"percent-encoding",
|
846 |
]
|
|
|
853 |
|
854 |
[[package]]
|
855 |
name = "futures"
|
856 |
+
version = "0.3.29"
|
857 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
858 |
+
checksum = "da0290714b38af9b4a7b094b8a37086d1b4e61f2df9122c3cad2577669145335"
|
859 |
dependencies = [
|
860 |
"futures-channel",
|
861 |
"futures-core",
|
|
|
1179 |
|
1180 |
[[package]]
|
1181 |
name = "idna"
|
1182 |
+
version = "0.5.0"
|
1183 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
1184 |
+
checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6"
|
1185 |
dependencies = [
|
1186 |
"unicode-bidi",
|
1187 |
"unicode-normalization",
|
|
|
1487 |
|
1488 |
[[package]]
|
1489 |
name = "percent-encoding"
|
1490 |
+
version = "2.3.1"
|
1491 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
1492 |
+
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
|
1493 |
|
1494 |
[[package]]
|
1495 |
name = "pest"
|
|
|
2535 |
|
2536 |
[[package]]
|
2537 |
name = "url"
|
2538 |
+
version = "2.5.0"
|
2539 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
2540 |
+
checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633"
|
2541 |
dependencies = [
|
2542 |
"form_urlencoded",
|
2543 |
"idna",
|
Cargo.toml
CHANGED
@@ -3,6 +3,14 @@ name = "polyhedron"
|
|
3 |
version = "0.1.0"
|
4 |
edition = "2021"
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
[dependencies]
|
7 |
anyhow = "1.0"
|
8 |
async-stream = "0.3"
|
|
|
3 |
version = "0.1.0"
|
4 |
edition = "2021"
|
5 |
|
6 |
+
[lib]
|
7 |
+
name = "polyhedron"
|
8 |
+
path = "src/lib.rs"
|
9 |
+
|
10 |
+
[[bin]]
|
11 |
+
name = "polyhedron"
|
12 |
+
path = "src/main.rs"
|
13 |
+
|
14 |
[dependencies]
|
15 |
anyhow = "1.0"
|
16 |
async-stream = "0.3"
|
src/asr/aws.rs
CHANGED
@@ -12,6 +12,7 @@ use tokio::select;
|
|
12 |
use tokio::sync::broadcast::Receiver;
|
13 |
use tokio_stream::Stream;
|
14 |
use futures_util::TryStreamExt;
|
|
|
15 |
use crate::asr::{ASR, Event, slice_i16_to_u8};
|
16 |
|
17 |
pub struct AwsAsr {
|
@@ -58,6 +59,7 @@ impl AwsAsr {
|
|
58 |
let output_stream = to_stream(output);
|
59 |
output_stream
|
60 |
.try_for_each(|text| async {
|
|
|
61 |
let _ = shared_speaker_transcript.send(text);
|
62 |
Ok(())
|
63 |
})
|
|
|
12 |
use tokio::sync::broadcast::Receiver;
|
13 |
use tokio_stream::Stream;
|
14 |
use futures_util::TryStreamExt;
|
15 |
+
use tracing::trace;
|
16 |
use crate::asr::{ASR, Event, slice_i16_to_u8};
|
17 |
|
18 |
pub struct AwsAsr {
|
|
|
59 |
let output_stream = to_stream(output);
|
60 |
output_stream
|
61 |
.try_for_each(|text| async {
|
62 |
+
trace!("here");
|
63 |
let _ = shared_speaker_transcript.send(text);
|
64 |
Ok(())
|
65 |
})
|
src/asr/mod.rs
CHANGED
@@ -18,8 +18,7 @@ pub(crate) trait ASR {
|
|
18 |
}
|
19 |
|
20 |
|
21 |
-
|
22 |
-
pub(crate) fn slice_i16_to_u8(slice: &[i16]) -> Vec<u8> {
|
23 |
slice
|
24 |
.iter()
|
25 |
.flat_map(|&sample| {
|
|
|
18 |
}
|
19 |
|
20 |
|
21 |
+
pub fn slice_i16_to_u8(slice: &[i16]) -> Vec<u8> {
|
|
|
22 |
slice
|
23 |
.iter()
|
24 |
.flat_map(|&sample| {
|
src/lesson.rs
CHANGED
@@ -1,9 +1,7 @@
|
|
1 |
use aws_config::SdkConfig;
|
2 |
-
use aws_sdk_polly::primitives::ByteStream;
|
3 |
use aws_sdk_polly::types::{Engine, OutputFormat, SpeechMarkType, VoiceId};
|
4 |
use aws_sdk_transcribestreaming::types::{LanguageCode};
|
5 |
use futures_util::future::try_join;
|
6 |
-
use serde::{Deserialize, Serialize};
|
7 |
use std::collections::BTreeMap;
|
8 |
use std::fmt::{Debug, Formatter};
|
9 |
use std::io::BufRead;
|
@@ -17,11 +15,14 @@ use crate::asr::{Event, aws::AwsAsr, ASR};
|
|
17 |
|
18 |
#[cfg(feature = "whisper")]
|
19 |
use crate::asr::whisper::WhisperAsr;
|
|
|
|
|
|
|
20 |
|
21 |
pub struct InnerLessonsManager {
|
22 |
translate_client: aws_sdk_translate::Client,
|
23 |
polly_client: aws_sdk_polly::Client,
|
24 |
-
lessons: Arc<RwLock<BTreeMap<
|
25 |
}
|
26 |
|
27 |
#[derive(Clone)]
|
@@ -43,7 +44,7 @@ impl Deref for LessonsManager {
|
|
43 |
}
|
44 |
}
|
45 |
|
46 |
-
pub
|
47 |
AWS,
|
48 |
#[allow(dead_code)]
|
49 |
#[cfg(feature = "whisper")]
|
@@ -69,7 +70,7 @@ impl AsrEngine {
|
|
69 |
}
|
70 |
|
71 |
impl LessonsManager {
|
72 |
-
pub
|
73 |
let translate_client = aws_sdk_translate::Client::new(sdk_config);
|
74 |
let polly_client = aws_sdk_polly::Client::new(sdk_config);
|
75 |
let inner = InnerLessonsManager {
|
@@ -80,14 +81,14 @@ impl LessonsManager {
|
|
80 |
LessonsManager { inner: Arc::new(inner) }
|
81 |
}
|
82 |
|
83 |
-
pub(crate) async fn create_lesson(&self, id:
|
84 |
let mut map = self.lessons.write().await;
|
85 |
let lesson: Lesson = InnerLesson::new(self.clone(), engine, speaker_lang).await.into();
|
86 |
map.insert(id, lesson.clone());
|
87 |
lesson
|
88 |
}
|
89 |
|
90 |
-
pub(crate) async fn get_lesson(&self, id:
|
91 |
let map = self.lessons.read().await;
|
92 |
map.get(&id).cloned()
|
93 |
}
|
@@ -434,12 +435,6 @@ impl Drop for InnerVoiceLesson {
|
|
434 |
}
|
435 |
|
436 |
|
437 |
-
// {"time":180,"type":"viseme","value":"r"}
|
438 |
-
#[derive(Debug, Deserialize, Clone, Serialize)]
|
439 |
-
pub(crate) struct Viseme {
|
440 |
-
time: u32,
|
441 |
-
value: String,
|
442 |
-
}
|
443 |
|
444 |
#[derive(Debug)]
|
445 |
enum SynthesizeError {
|
|
|
1 |
use aws_config::SdkConfig;
|
|
|
2 |
use aws_sdk_polly::types::{Engine, OutputFormat, SpeechMarkType, VoiceId};
|
3 |
use aws_sdk_transcribestreaming::types::{LanguageCode};
|
4 |
use futures_util::future::try_join;
|
|
|
5 |
use std::collections::BTreeMap;
|
6 |
use std::fmt::{Debug, Formatter};
|
7 |
use std::io::BufRead;
|
|
|
15 |
|
16 |
#[cfg(feature = "whisper")]
|
17 |
use crate::asr::whisper::WhisperAsr;
|
18 |
+
use crate::Viseme;
|
19 |
+
|
20 |
+
pub type LessonID = String;
|
21 |
|
22 |
pub struct InnerLessonsManager {
|
23 |
translate_client: aws_sdk_translate::Client,
|
24 |
polly_client: aws_sdk_polly::Client,
|
25 |
+
lessons: Arc<RwLock<BTreeMap<LessonID, Lesson>>>,
|
26 |
}
|
27 |
|
28 |
#[derive(Clone)]
|
|
|
44 |
}
|
45 |
}
|
46 |
|
47 |
+
pub enum AsrEngine {
|
48 |
AWS,
|
49 |
#[allow(dead_code)]
|
50 |
#[cfg(feature = "whisper")]
|
|
|
70 |
}
|
71 |
|
72 |
impl LessonsManager {
|
73 |
+
pub fn new(sdk_config: &SdkConfig) -> Self {
|
74 |
let translate_client = aws_sdk_translate::Client::new(sdk_config);
|
75 |
let polly_client = aws_sdk_polly::Client::new(sdk_config);
|
76 |
let inner = InnerLessonsManager {
|
|
|
81 |
LessonsManager { inner: Arc::new(inner) }
|
82 |
}
|
83 |
|
84 |
+
pub(crate) async fn create_lesson(&self, id: LessonID, engine: AsrEngine, speaker_lang: LanguageCode) -> Lesson {
|
85 |
let mut map = self.lessons.write().await;
|
86 |
let lesson: Lesson = InnerLesson::new(self.clone(), engine, speaker_lang).await.into();
|
87 |
map.insert(id, lesson.clone());
|
88 |
lesson
|
89 |
}
|
90 |
|
91 |
+
pub(crate) async fn get_lesson(&self, id: LessonID) -> Option<Lesson> {
|
92 |
let map = self.lessons.read().await;
|
93 |
map.get(&id).cloned()
|
94 |
}
|
|
|
435 |
}
|
436 |
|
437 |
|
|
|
|
|
|
|
|
|
|
|
|
|
438 |
|
439 |
#[derive(Debug)]
|
440 |
enum SynthesizeError {
|
src/lib.rs
ADDED
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#[cfg(feature = "whisper")]
|
2 |
+
extern crate whisper;
|
3 |
+
|
4 |
+
pub mod config;
|
5 |
+
pub mod lesson;
|
6 |
+
pub mod asr;
|
7 |
+
pub mod base64box;
|
8 |
+
|
9 |
+
use aws_config::{SdkConfig};
|
10 |
+
use aws_sdk_transcribestreaming::types::LanguageCode;
|
11 |
+
use futures_util::{stream::StreamExt, SinkExt};
|
12 |
+
use poem::{
|
13 |
+
endpoint::{StaticFileEndpoint, StaticFilesEndpoint},
|
14 |
+
get, handler,
|
15 |
+
listener::TcpListener,
|
16 |
+
web::{
|
17 |
+
websocket::{Message, WebSocket},
|
18 |
+
Data, Query,
|
19 |
+
},
|
20 |
+
EndpointExt, IntoResponse, Route, Server,
|
21 |
+
};
|
22 |
+
use serde::{Deserialize, Serialize};
|
23 |
+
use tokio::select;
|
24 |
+
use tracing::{debug, span};
|
25 |
+
|
26 |
+
use crate::base64box::Base64Box;
|
27 |
+
use crate::config::SETTINGS;
|
28 |
+
use crate::lesson::{AsrEngine, LessonID, LessonsManager};
|
29 |
+
|
30 |
+
#[derive(Clone)]
|
31 |
+
pub struct Context {
|
32 |
+
lessons_manager: LessonsManager,
|
33 |
+
}
|
34 |
+
|
35 |
+
impl Context {
|
36 |
+
pub fn new(config: &SdkConfig) -> Self {
|
37 |
+
Self {
|
38 |
+
lessons_manager: LessonsManager::new(config),
|
39 |
+
}
|
40 |
+
}
|
41 |
+
}
|
42 |
+
|
43 |
+
#[derive(Serialize, Deserialize, Debug)]
|
44 |
+
pub enum SingleEvent {
|
45 |
+
#[serde(rename = "original")]
|
46 |
+
Transcription {
|
47 |
+
content: String,
|
48 |
+
#[serde(rename = "isFinal")]
|
49 |
+
is_final: bool
|
50 |
+
},
|
51 |
+
#[serde(rename = "translated")]
|
52 |
+
Translation { content: String },
|
53 |
+
#[serde(rename = "voice")]
|
54 |
+
Voice {
|
55 |
+
content: Base64Box
|
56 |
+
},
|
57 |
+
}
|
58 |
+
|
59 |
+
|
60 |
+
#[derive(Deserialize, Debug)]
|
61 |
+
pub struct LessonListenerQuery {
|
62 |
+
id: String,
|
63 |
+
language: String,
|
64 |
+
voice: String,
|
65 |
+
}
|
66 |
+
|
67 |
+
#[derive(Serialize, Debug)]
|
68 |
+
#[serde(tag = "type")]
|
69 |
+
pub enum LiveLessonTextEvent {
|
70 |
+
#[serde(rename = "original")]
|
71 |
+
Transcription {
|
72 |
+
content: String,
|
73 |
+
#[serde(rename = "isFinal")]
|
74 |
+
is_final: bool
|
75 |
+
},
|
76 |
+
Translation { content: String },
|
77 |
+
LipSync { visemes: Vec<Viseme> },
|
78 |
+
}
|
79 |
+
|
80 |
+
|
81 |
+
// {"time":180,"type":"viseme","value":"r"}
|
82 |
+
#[derive(Debug, Deserialize, Clone, Serialize)]
|
83 |
+
pub struct Viseme {
|
84 |
+
time: u32,
|
85 |
+
value: String,
|
86 |
+
}
|
87 |
+
|
88 |
+
#[derive(Deserialize, Debug)]
|
89 |
+
pub struct LessonSpeakerQuery {
|
90 |
+
id: LessonID,
|
91 |
+
language: String,
|
92 |
+
#[allow(dead_code)] // TODO: use this in the future
|
93 |
+
prompt: Option<String>,
|
94 |
+
}
|
95 |
+
|
96 |
+
#[handler]
|
97 |
+
async fn stream_speaker(
|
98 |
+
ctx: Data<&Context>,
|
99 |
+
query: Query<LessonSpeakerQuery>,
|
100 |
+
ws: WebSocket,
|
101 |
+
) -> impl IntoResponse {
|
102 |
+
let lessons_manager = ctx.lessons_manager.clone();
|
103 |
+
ws.on_upgrade(|mut socket| async move {
|
104 |
+
let Ok(lang) = query.language.parse::<LanguageCode>() else {
|
105 |
+
let _ = socket
|
106 |
+
.send(Message::Text(format!("invalid language code: {}", query.language)))
|
107 |
+
.await;
|
108 |
+
return
|
109 |
+
};
|
110 |
+
let lesson = lessons_manager
|
111 |
+
.create_lesson(
|
112 |
+
query.id.clone(),
|
113 |
+
AsrEngine::AWS,
|
114 |
+
lang,
|
115 |
+
)
|
116 |
+
.await;
|
117 |
+
|
118 |
+
let mut transcribe_rx = lesson.transcript_channel();
|
119 |
+
let fut = async {
|
120 |
+
loop {
|
121 |
+
select! {
|
122 |
+
msg = socket.next() => {
|
123 |
+
let Some(res) = msg else { break };
|
124 |
+
let msg = res?;
|
125 |
+
if msg.is_close() {
|
126 |
+
break
|
127 |
+
}
|
128 |
+
let Message::Binary(bin) = msg else {
|
129 |
+
tracing::warn!("Other: {:?}", msg);
|
130 |
+
continue
|
131 |
+
};
|
132 |
+
let frame = u8_to_i16(&bin);
|
133 |
+
lesson.send(frame).await?
|
134 |
+
},
|
135 |
+
output = transcribe_rx.recv() => {
|
136 |
+
let evt = output?;
|
137 |
+
if evt.is_final {
|
138 |
+
tracing::trace!("Transcribed: {}", evt.transcript);
|
139 |
+
}
|
140 |
+
let evt = LiveLessonTextEvent::Transcription { content: evt.transcript, is_final: evt.is_final };
|
141 |
+
let Ok(json) = serde_json::to_string(&evt) else {
|
142 |
+
tracing::warn!("failed to serialize json: {:?}", evt);
|
143 |
+
continue
|
144 |
+
};
|
145 |
+
socket.send(Message::Text(json)).await?
|
146 |
+
},
|
147 |
+
}
|
148 |
+
}
|
149 |
+
Ok(())
|
150 |
+
};
|
151 |
+
|
152 |
+
let span = span!(tracing::Level::TRACE, "lesson_speaker", lesson_id = query.id);
|
153 |
+
let _ = span.enter();
|
154 |
+
let res: anyhow::Result<()> = fut.await;
|
155 |
+
match res {
|
156 |
+
Ok(()) => {
|
157 |
+
tracing::info!("lesson speaker closed");
|
158 |
+
}
|
159 |
+
Err(e) => {
|
160 |
+
tracing::warn!("lesson speaker error: {}", e);
|
161 |
+
}
|
162 |
+
}
|
163 |
+
})
|
164 |
+
}
|
165 |
+
|
166 |
+
#[handler]
|
167 |
+
async fn stream_listener(
|
168 |
+
ctx: Data<&Context>,
|
169 |
+
query: Query<LessonListenerQuery>,
|
170 |
+
ws: WebSocket,
|
171 |
+
) -> impl IntoResponse {
|
172 |
+
let lessons_manager = ctx.lessons_manager.clone();
|
173 |
+
|
174 |
+
ws.on_upgrade(|mut socket| async move {
|
175 |
+
let voice_id = match query.voice.parse() {
|
176 |
+
Ok(id) => id,
|
177 |
+
Err(e) => {
|
178 |
+
let _ = socket
|
179 |
+
.send(Message::Text(format!("invalid voice id: {}", e)))
|
180 |
+
.await;
|
181 |
+
return;
|
182 |
+
}
|
183 |
+
};
|
184 |
+
|
185 |
+
let lesson_opt = lessons_manager.get_lesson(query.id.clone()).await;
|
186 |
+
debug!("listener param = {:?}", query);
|
187 |
+
let Some(lesson) = lesson_opt else {
|
188 |
+
let _ = socket
|
189 |
+
.send(Message::Text("lesson not found".to_string()))
|
190 |
+
.await;
|
191 |
+
return;
|
192 |
+
};
|
193 |
+
let mut transcript_rx = lesson.transcript_channel();
|
194 |
+
let mut lang_lesson = lesson.get_or_init(query.language.clone()).await;
|
195 |
+
let mut translate_rx = lang_lesson.translated_channel();
|
196 |
+
let mut voice_lesson = lang_lesson.get_or_init(voice_id).await;
|
197 |
+
let mut voice_rx = voice_lesson.voice_channel();
|
198 |
+
let mut lip_sync_rx = voice_lesson.lip_sync_channel();
|
199 |
+
|
200 |
+
let fut = async {
|
201 |
+
loop {
|
202 |
+
select! {
|
203 |
+
transcript_poll = transcript_rx.recv() => {
|
204 |
+
let transcript = transcript_poll?;
|
205 |
+
let evt = LiveLessonTextEvent::Transcription {
|
206 |
+
content: transcript.transcript,
|
207 |
+
is_final: transcript.is_final
|
208 |
+
};
|
209 |
+
let Ok(json) = serde_json::to_string(&evt) else {
|
210 |
+
tracing::warn!("failed to serialize: {:?}", evt);
|
211 |
+
continue
|
212 |
+
};
|
213 |
+
tracing::debug!("Transcribed: {}", json);
|
214 |
+
socket.send(Message::Text(json)).await?
|
215 |
+
},
|
216 |
+
translated_poll = translate_rx.recv() => {
|
217 |
+
let translated = translated_poll?;
|
218 |
+
let evt = LiveLessonTextEvent::Translation { content: translated };
|
219 |
+
let Ok(json) = serde_json::to_string(&evt) else {
|
220 |
+
tracing::warn!("failed to serialize: {:?}", evt);
|
221 |
+
continue
|
222 |
+
};
|
223 |
+
tracing::debug!("Translated: {}", json);
|
224 |
+
socket.send(Message::Text(json)).await?
|
225 |
+
},
|
226 |
+
voice_poll = voice_rx.recv() => {
|
227 |
+
let voice = voice_poll?;
|
228 |
+
socket.send(Message::Binary(voice)).await?
|
229 |
+
},
|
230 |
+
visemes_poll = lip_sync_rx.recv() => {
|
231 |
+
let visemes = visemes_poll?;
|
232 |
+
let evt = LiveLessonTextEvent::LipSync { visemes };
|
233 |
+
let Ok(json) = serde_json::to_string(&evt) else {
|
234 |
+
tracing::warn!("failed to serialize: {:?}", evt);
|
235 |
+
continue
|
236 |
+
};
|
237 |
+
socket.send(Message::Text(json)).await?
|
238 |
+
},
|
239 |
+
}
|
240 |
+
}
|
241 |
+
};
|
242 |
+
|
243 |
+
let span = span!(tracing::Level::TRACE, "lesson_listener", lesson_id = query.id);
|
244 |
+
let _ = span.enter();
|
245 |
+
let res: anyhow::Result<()> = fut.await;
|
246 |
+
match res {
|
247 |
+
Ok(()) => {
|
248 |
+
tracing::info!("lesson listener closed");
|
249 |
+
}
|
250 |
+
Err(e) => {
|
251 |
+
tracing::warn!("lesson listener error: {}", e);
|
252 |
+
}
|
253 |
+
}
|
254 |
+
})
|
255 |
+
}
|
256 |
+
|
257 |
+
#[derive(Deserialize, Debug)]
|
258 |
+
pub struct SingleQuery {
|
259 |
+
id: LessonID,
|
260 |
+
from: String,
|
261 |
+
to: String,
|
262 |
+
voice: Option<String>,
|
263 |
+
}
|
264 |
+
|
265 |
+
|
266 |
+
#[handler]
|
267 |
+
pub async fn stream_single(
|
268 |
+
ctx: Data<&Context>,
|
269 |
+
query: Query<SingleQuery>,
|
270 |
+
ws: WebSocket
|
271 |
+
) -> impl IntoResponse {
|
272 |
+
let lessons_manager = ctx.lessons_manager.clone();
|
273 |
+
ws.on_upgrade(|mut socket| async move {
|
274 |
+
let Ok(lang) = query.from.parse::<LanguageCode>() else {
|
275 |
+
let _ = socket
|
276 |
+
.send(Message::Text(format!("invalid language code: {}", query.from)))
|
277 |
+
.await;
|
278 |
+
return
|
279 |
+
};
|
280 |
+
let lesson = lessons_manager
|
281 |
+
.create_lesson(
|
282 |
+
query.id.clone(),
|
283 |
+
AsrEngine::AWS,
|
284 |
+
lang,
|
285 |
+
)
|
286 |
+
.await;
|
287 |
+
|
288 |
+
let mut transcribe_rx = lesson.transcript_channel();
|
289 |
+
let mut lang_lesson = lesson.get_or_init(query.to.clone()).await;
|
290 |
+
let mut translate_rx = lang_lesson.translated_channel();
|
291 |
+
let Ok(voice_id) = query.voice.as_deref().unwrap_or("Amy").parse() else {
|
292 |
+
let _ = socket
|
293 |
+
.send(Message::Text(format!("invalid voice id: {:?}", query.voice)))
|
294 |
+
.await;
|
295 |
+
return
|
296 |
+
};
|
297 |
+
let mut voice_lesson = lang_lesson.get_or_init(voice_id).await;
|
298 |
+
let mut voice_rx = voice_lesson.voice_channel();
|
299 |
+
// let mut lip_sync_rx = voice_lesson.lip_sync_channel();
|
300 |
+
|
301 |
+
let fut = async {
|
302 |
+
loop {
|
303 |
+
let evt = select! {
|
304 |
+
input = socket.next() => {
|
305 |
+
let Some(res) = input else { break };
|
306 |
+
let msg = res?;
|
307 |
+
if msg.is_close() {
|
308 |
+
break
|
309 |
+
}
|
310 |
+
let Message::Binary(bin) = msg else {
|
311 |
+
tracing::warn!("Other: {:?}", msg);
|
312 |
+
continue
|
313 |
+
};
|
314 |
+
let frame = u8_to_i16(&bin);
|
315 |
+
lesson.send(frame).await?;
|
316 |
+
continue
|
317 |
+
},
|
318 |
+
transcript_poll = transcribe_rx.recv() => {
|
319 |
+
let evt = transcript_poll?;
|
320 |
+
if evt.is_final {
|
321 |
+
tracing::trace!("Transcribed: {}", evt.transcript);
|
322 |
+
}
|
323 |
+
SingleEvent::Transcription { content: evt.transcript, is_final: evt.is_final }
|
324 |
+
},
|
325 |
+
translated_poll = translate_rx.recv() => {
|
326 |
+
let translated = translated_poll?;
|
327 |
+
SingleEvent::Translation { content: translated }
|
328 |
+
},
|
329 |
+
voice_poll = voice_rx.recv() => {
|
330 |
+
let voice = voice_poll?;
|
331 |
+
SingleEvent::Voice { content: Base64Box(voice) }
|
332 |
+
},
|
333 |
+
};
|
334 |
+
|
335 |
+
let Ok(json) = serde_json::to_string(&evt) else {
|
336 |
+
tracing::warn!("failed to serialize json: {:?}", evt);
|
337 |
+
continue
|
338 |
+
};
|
339 |
+
socket.send(Message::Text(json)).await?
|
340 |
+
}
|
341 |
+
Ok(())
|
342 |
+
};
|
343 |
+
|
344 |
+
let span = span!(tracing::Level::TRACE, "lesson_speaker", lesson_id = query.id);
|
345 |
+
let _ = span.enter();
|
346 |
+
let res: anyhow::Result<()> = fut.await;
|
347 |
+
match res {
|
348 |
+
Ok(()) => {
|
349 |
+
tracing::info!("lesson speaker closed");
|
350 |
+
}
|
351 |
+
Err(e) => {
|
352 |
+
tracing::warn!("lesson speaker error: {}", e);
|
353 |
+
}
|
354 |
+
}
|
355 |
+
})
|
356 |
+
}
|
357 |
+
|
358 |
+
fn u8_to_i16(input: &[u8]) -> Vec<i16> {
|
359 |
+
input
|
360 |
+
.chunks_exact(2)
|
361 |
+
.map(|chunk| {
|
362 |
+
let mut buf = [0u8; 2];
|
363 |
+
buf.copy_from_slice(chunk);
|
364 |
+
i16::from_le_bytes(buf)
|
365 |
+
})
|
366 |
+
.collect::<Vec<i16>>()
|
367 |
+
}
|
368 |
+
|
369 |
+
pub async fn app(config: &SdkConfig) -> Result<(), std::io::Error> {
|
370 |
+
let ctx = Context {
|
371 |
+
lessons_manager: LessonsManager::new(config),
|
372 |
+
};
|
373 |
+
|
374 |
+
let app = Route::new()
|
375 |
+
.nest(
|
376 |
+
"/",
|
377 |
+
StaticFilesEndpoint::new("./static")
|
378 |
+
.show_files_listing()
|
379 |
+
.index_file("index.html"),
|
380 |
+
)
|
381 |
+
.at("/ws/lesson-speaker", get(stream_speaker))
|
382 |
+
.at("/ws/teacher", get(stream_speaker))
|
383 |
+
.at("/ws/lesson-listener", get(stream_listener))
|
384 |
+
.at("/ws/student", get(stream_listener))
|
385 |
+
.at("/ws/voice", get(stream_single))
|
386 |
+
.at(
|
387 |
+
"lesson-speaker",
|
388 |
+
StaticFileEndpoint::new("./static/index.html"),
|
389 |
+
)
|
390 |
+
.at(
|
391 |
+
"lesson-listener",
|
392 |
+
StaticFileEndpoint::new("./static/index.html"),
|
393 |
+
)
|
394 |
+
.data(ctx);
|
395 |
+
let addr = format!("{}:{}", SETTINGS.server.host, SETTINGS.server.port);
|
396 |
+
let listener = TcpListener::bind(addr);
|
397 |
+
let server = Server::new(listener);
|
398 |
+
|
399 |
+
server.run(app).await
|
400 |
+
}
|
src/main.rs
CHANGED
@@ -3,42 +3,14 @@
|
|
3 |
* SPDX-License-Identifier: Apache-2.0.
|
4 |
*/
|
5 |
|
6 |
-
#![allow(clippy::result_large_err)]
|
7 |
-
|
8 |
-
#[cfg(feature = "whisper")]
|
9 |
-
extern crate whisper;
|
10 |
-
|
11 |
use aws_config::BehaviorVersion;
|
12 |
use aws_sdk_transcribestreaming::meta::PKG_VERSION;
|
13 |
-
use aws_sdk_transcribestreaming::types::LanguageCode;
|
14 |
-
use futures_util::{stream::StreamExt, SinkExt};
|
15 |
-
use poem::{
|
16 |
-
endpoint::{StaticFileEndpoint, StaticFilesEndpoint},
|
17 |
-
get, handler,
|
18 |
-
listener::TcpListener,
|
19 |
-
web::{
|
20 |
-
websocket::{Message, WebSocket},
|
21 |
-
Data, Query,
|
22 |
-
},
|
23 |
-
EndpointExt, IntoResponse, Route, Server,
|
24 |
-
};
|
25 |
-
use serde::{Deserialize, Serialize};
|
26 |
use tokio::select;
|
27 |
-
use tracing::{debug
|
28 |
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
29 |
|
30 |
-
use
|
31 |
-
use crate::base64box::Base64Box;
|
32 |
|
33 |
-
mod config;
|
34 |
-
mod lesson;
|
35 |
-
mod asr;
|
36 |
-
mod base64box;
|
37 |
-
|
38 |
-
#[derive(Clone)]
|
39 |
-
struct Context {
|
40 |
-
lessons_manager: LessonsManager,
|
41 |
-
}
|
42 |
|
43 |
#[tokio::main]
|
44 |
async fn main() -> Result<(), std::io::Error> {
|
@@ -50,476 +22,13 @@ async fn main() -> Result<(), std::io::Error> {
|
|
50 |
debug!("Transcribe client version: {}", PKG_VERSION);
|
51 |
|
52 |
let shared_config = aws_config::load_defaults(BehaviorVersion::latest()).await;
|
53 |
-
let ctx = Context {
|
54 |
-
lessons_manager: LessonsManager::new(&shared_config),
|
55 |
-
};
|
56 |
|
57 |
-
let app = Route::new()
|
58 |
-
.nest(
|
59 |
-
"/",
|
60 |
-
StaticFilesEndpoint::new("./static")
|
61 |
-
.show_files_listing()
|
62 |
-
.index_file("index.html"),
|
63 |
-
)
|
64 |
-
.at("/ws/lesson-speaker", get(stream_speaker))
|
65 |
-
.at("/ws/teacher", get(stream_speaker))
|
66 |
-
.at("/ws/lesson-listener", get(stream_listener))
|
67 |
-
.at("/ws/student", get(stream_listener))
|
68 |
-
.at("/ws/voice", get(stream_single))
|
69 |
-
.at(
|
70 |
-
"lesson-speaker",
|
71 |
-
StaticFileEndpoint::new("./static/index.html"),
|
72 |
-
)
|
73 |
-
.at(
|
74 |
-
"lesson-listener",
|
75 |
-
StaticFileEndpoint::new("./static/index.html"),
|
76 |
-
)
|
77 |
-
.data(ctx);
|
78 |
-
let addr = format!("{}:{}", SETTINGS.server.host, SETTINGS.server.port);
|
79 |
-
let listener = TcpListener::bind(addr);
|
80 |
-
let server = Server::new(listener);
|
81 |
|
82 |
select! {
|
83 |
-
res =
|
84 |
_ = tokio::signal::ctrl_c() => {
|
85 |
tracing::info!("Shutting down");
|
86 |
Ok(())
|
87 |
},
|
88 |
}
|
89 |
}
|
90 |
-
|
91 |
-
#[derive(Deserialize, Debug)]
|
92 |
-
pub struct LessonSpeakerQuery {
|
93 |
-
id: u32,
|
94 |
-
language: String,
|
95 |
-
#[allow(dead_code)] // TODO: use this in the future
|
96 |
-
prompt: Option<String>,
|
97 |
-
}
|
98 |
-
|
99 |
-
#[handler]
|
100 |
-
async fn stream_speaker(
|
101 |
-
ctx: Data<&Context>,
|
102 |
-
query: Query<LessonSpeakerQuery>,
|
103 |
-
ws: WebSocket,
|
104 |
-
) -> impl IntoResponse {
|
105 |
-
let lessons_manager = ctx.lessons_manager.clone();
|
106 |
-
ws.on_upgrade(|mut socket| async move {
|
107 |
-
let Ok(lang) = query.language.parse::<LanguageCode>() else {
|
108 |
-
let _ = socket
|
109 |
-
.send(Message::Text(format!("invalid language code: {}", query.language)))
|
110 |
-
.await;
|
111 |
-
return
|
112 |
-
};
|
113 |
-
let lesson = lessons_manager
|
114 |
-
.create_lesson(
|
115 |
-
query.id,
|
116 |
-
AsrEngine::AWS,
|
117 |
-
lang,
|
118 |
-
)
|
119 |
-
.await;
|
120 |
-
|
121 |
-
let mut transcribe_rx = lesson.transcript_channel();
|
122 |
-
let fut = async {
|
123 |
-
loop {
|
124 |
-
select! {
|
125 |
-
msg = socket.next() => {
|
126 |
-
let Some(res) = msg else { break };
|
127 |
-
let msg = res?;
|
128 |
-
if msg.is_close() {
|
129 |
-
break
|
130 |
-
}
|
131 |
-
let Message::Binary(bin) = msg else {
|
132 |
-
tracing::warn!("Other: {:?}", msg);
|
133 |
-
continue
|
134 |
-
};
|
135 |
-
let frame = u8_to_i16(&bin);
|
136 |
-
lesson.send(frame).await?
|
137 |
-
},
|
138 |
-
output = transcribe_rx.recv() => {
|
139 |
-
let evt = output?;
|
140 |
-
if evt.is_final {
|
141 |
-
tracing::trace!("Transcribed: {}", evt.transcript);
|
142 |
-
}
|
143 |
-
let evt = LiveLessonTextEvent::Transcription { content: evt.transcript, is_final: evt.is_final };
|
144 |
-
let Ok(json) = serde_json::to_string(&evt) else {
|
145 |
-
tracing::warn!("failed to serialize json: {:?}", evt);
|
146 |
-
continue
|
147 |
-
};
|
148 |
-
socket.send(Message::Text(json)).await?
|
149 |
-
},
|
150 |
-
}
|
151 |
-
}
|
152 |
-
Ok(())
|
153 |
-
};
|
154 |
-
|
155 |
-
let span = span!(tracing::Level::TRACE, "lesson_speaker", lesson_id = query.id);
|
156 |
-
let _ = span.enter();
|
157 |
-
let res: anyhow::Result<()> = fut.await;
|
158 |
-
match res {
|
159 |
-
Ok(()) => {
|
160 |
-
tracing::info!("lesson speaker closed");
|
161 |
-
}
|
162 |
-
Err(e) => {
|
163 |
-
tracing::warn!("lesson speaker error: {}", e);
|
164 |
-
}
|
165 |
-
}
|
166 |
-
})
|
167 |
-
}
|
168 |
-
|
169 |
-
#[derive(Deserialize, Debug)]
|
170 |
-
pub struct LessonListenerQuery {
|
171 |
-
id: u32,
|
172 |
-
language: String,
|
173 |
-
voice: String,
|
174 |
-
}
|
175 |
-
|
176 |
-
#[derive(Serialize, Debug)]
|
177 |
-
#[serde(tag = "type")]
|
178 |
-
enum LiveLessonTextEvent {
|
179 |
-
#[serde(rename = "original")]
|
180 |
-
Transcription {
|
181 |
-
content: String,
|
182 |
-
#[serde(rename = "isFinal")]
|
183 |
-
is_final: bool
|
184 |
-
},
|
185 |
-
Translation { content: String },
|
186 |
-
LipSync { visemes: Vec<Viseme> },
|
187 |
-
}
|
188 |
-
#[handler]
|
189 |
-
async fn stream_listener(
|
190 |
-
ctx: Data<&Context>,
|
191 |
-
query: Query<LessonListenerQuery>,
|
192 |
-
ws: WebSocket,
|
193 |
-
) -> impl IntoResponse {
|
194 |
-
let lessons_manager = ctx.lessons_manager.clone();
|
195 |
-
|
196 |
-
ws.on_upgrade(|mut socket| async move {
|
197 |
-
let voice_id = match query.voice.parse() {
|
198 |
-
Ok(id) => id,
|
199 |
-
Err(e) => {
|
200 |
-
let _ = socket
|
201 |
-
.send(Message::Text(format!("invalid voice id: {}", e)))
|
202 |
-
.await;
|
203 |
-
return;
|
204 |
-
}
|
205 |
-
};
|
206 |
-
|
207 |
-
let lesson_opt = lessons_manager.get_lesson(query.id).await;
|
208 |
-
debug!("listener param = {:?}", query);
|
209 |
-
let Some(lesson) = lesson_opt else {
|
210 |
-
let _ = socket
|
211 |
-
.send(Message::Text("lesson not found".to_string()))
|
212 |
-
.await;
|
213 |
-
return;
|
214 |
-
};
|
215 |
-
let mut transcript_rx = lesson.transcript_channel();
|
216 |
-
let mut lang_lesson = lesson.get_or_init(query.language.clone()).await;
|
217 |
-
let mut translate_rx = lang_lesson.translated_channel();
|
218 |
-
let mut voice_lesson = lang_lesson.get_or_init(voice_id).await;
|
219 |
-
let mut voice_rx = voice_lesson.voice_channel();
|
220 |
-
let mut lip_sync_rx = voice_lesson.lip_sync_channel();
|
221 |
-
|
222 |
-
let fut = async {
|
223 |
-
loop {
|
224 |
-
select! {
|
225 |
-
transcript_poll = transcript_rx.recv() => {
|
226 |
-
let transcript = transcript_poll?;
|
227 |
-
let evt = LiveLessonTextEvent::Transcription {
|
228 |
-
content: transcript.transcript,
|
229 |
-
is_final: transcript.is_final
|
230 |
-
};
|
231 |
-
let Ok(json) = serde_json::to_string(&evt) else {
|
232 |
-
tracing::warn!("failed to serialize: {:?}", evt);
|
233 |
-
continue
|
234 |
-
};
|
235 |
-
tracing::debug!("Transcribed: {}", json);
|
236 |
-
socket.send(Message::Text(json)).await?
|
237 |
-
},
|
238 |
-
translated_poll = translate_rx.recv() => {
|
239 |
-
let translated = translated_poll?;
|
240 |
-
let evt = LiveLessonTextEvent::Translation { content: translated };
|
241 |
-
let Ok(json) = serde_json::to_string(&evt) else {
|
242 |
-
tracing::warn!("failed to serialize: {:?}", evt);
|
243 |
-
continue
|
244 |
-
};
|
245 |
-
tracing::debug!("Translated: {}", json);
|
246 |
-
socket.send(Message::Text(json)).await?
|
247 |
-
},
|
248 |
-
voice_poll = voice_rx.recv() => {
|
249 |
-
let voice = voice_poll?;
|
250 |
-
socket.send(Message::Binary(voice)).await?
|
251 |
-
},
|
252 |
-
visemes_poll = lip_sync_rx.recv() => {
|
253 |
-
let visemes = visemes_poll?;
|
254 |
-
let evt = LiveLessonTextEvent::LipSync { visemes };
|
255 |
-
let Ok(json) = serde_json::to_string(&evt) else {
|
256 |
-
tracing::warn!("failed to serialize: {:?}", evt);
|
257 |
-
continue
|
258 |
-
};
|
259 |
-
socket.send(Message::Text(json)).await?
|
260 |
-
},
|
261 |
-
}
|
262 |
-
}
|
263 |
-
};
|
264 |
-
|
265 |
-
let span = span!(tracing::Level::TRACE, "lesson_listener", lesson_id = query.id);
|
266 |
-
let _ = span.enter();
|
267 |
-
let res: anyhow::Result<()> = fut.await;
|
268 |
-
match res {
|
269 |
-
Ok(()) => {
|
270 |
-
tracing::info!("lesson listener closed");
|
271 |
-
}
|
272 |
-
Err(e) => {
|
273 |
-
tracing::warn!("lesson listener error: {}", e);
|
274 |
-
}
|
275 |
-
}
|
276 |
-
})
|
277 |
-
}
|
278 |
-
|
279 |
-
#[derive(Serialize, Deserialize, Debug)]
|
280 |
-
enum SingleEvent {
|
281 |
-
#[serde(rename = "original")]
|
282 |
-
Transcription {
|
283 |
-
content: String,
|
284 |
-
#[serde(rename = "isFinal")]
|
285 |
-
is_final: bool
|
286 |
-
},
|
287 |
-
#[serde(rename = "translated")]
|
288 |
-
Translation { content: String },
|
289 |
-
#[serde(rename = "voice")]
|
290 |
-
Voice {
|
291 |
-
content: Base64Box
|
292 |
-
},
|
293 |
-
}
|
294 |
-
|
295 |
-
|
296 |
-
#[derive(Deserialize, Debug)]
|
297 |
-
pub struct SingleQuery {
|
298 |
-
id: u32,
|
299 |
-
from: String,
|
300 |
-
to: String,
|
301 |
-
voice: Option<String>,
|
302 |
-
}
|
303 |
-
|
304 |
-
|
305 |
-
#[handler]
|
306 |
-
async fn stream_single(
|
307 |
-
ctx: Data<&Context>,
|
308 |
-
query: Query<SingleQuery>,
|
309 |
-
ws: WebSocket
|
310 |
-
) -> impl IntoResponse {
|
311 |
-
let lessons_manager = ctx.lessons_manager.clone();
|
312 |
-
ws.on_upgrade(|mut socket| async move {
|
313 |
-
let Ok(lang) = query.from.parse::<LanguageCode>() else {
|
314 |
-
let _ = socket
|
315 |
-
.send(Message::Text(format!("invalid language code: {}", query.from)))
|
316 |
-
.await;
|
317 |
-
return
|
318 |
-
};
|
319 |
-
let lesson = lessons_manager
|
320 |
-
.create_lesson(
|
321 |
-
query.id,
|
322 |
-
AsrEngine::AWS,
|
323 |
-
lang,
|
324 |
-
)
|
325 |
-
.await;
|
326 |
-
|
327 |
-
let mut transcribe_rx = lesson.transcript_channel();
|
328 |
-
let mut lang_lesson = lesson.get_or_init(query.to.clone()).await;
|
329 |
-
let mut translate_rx = lang_lesson.translated_channel();
|
330 |
-
let Ok(voice_id) = query.voice.as_deref().unwrap_or("Amy").parse() else {
|
331 |
-
let _ = socket
|
332 |
-
.send(Message::Text(format!("invalid voice id: {:?}", query.voice)))
|
333 |
-
.await;
|
334 |
-
return
|
335 |
-
};
|
336 |
-
let mut voice_lesson = lang_lesson.get_or_init(voice_id).await;
|
337 |
-
let mut voice_rx = voice_lesson.voice_channel();
|
338 |
-
// let mut lip_sync_rx = voice_lesson.lip_sync_channel();
|
339 |
-
|
340 |
-
let fut = async {
|
341 |
-
loop {
|
342 |
-
let evt = select! {
|
343 |
-
input = socket.next() => {
|
344 |
-
let Some(res) = input else { break };
|
345 |
-
let msg = res?;
|
346 |
-
if msg.is_close() {
|
347 |
-
break
|
348 |
-
}
|
349 |
-
let Message::Binary(bin) = msg else {
|
350 |
-
tracing::warn!("Other: {:?}", msg);
|
351 |
-
continue
|
352 |
-
};
|
353 |
-
let frame = u8_to_i16(&bin);
|
354 |
-
lesson.send(frame).await?;
|
355 |
-
continue
|
356 |
-
},
|
357 |
-
transcript_poll = transcribe_rx.recv() => {
|
358 |
-
let evt = transcript_poll?;
|
359 |
-
if evt.is_final {
|
360 |
-
tracing::trace!("Transcribed: {}", evt.transcript);
|
361 |
-
}
|
362 |
-
SingleEvent::Transcription { content: evt.transcript, is_final: evt.is_final }
|
363 |
-
},
|
364 |
-
translated_poll = translate_rx.recv() => {
|
365 |
-
let translated = translated_poll?;
|
366 |
-
SingleEvent::Translation { content: translated }
|
367 |
-
},
|
368 |
-
voice_poll = voice_rx.recv() => {
|
369 |
-
let voice = voice_poll?;
|
370 |
-
SingleEvent::Voice { content: Base64Box(voice) }
|
371 |
-
},
|
372 |
-
};
|
373 |
-
|
374 |
-
let Ok(json) = serde_json::to_string(&evt) else {
|
375 |
-
tracing::warn!("failed to serialize json: {:?}", evt);
|
376 |
-
continue
|
377 |
-
};
|
378 |
-
socket.send(Message::Text(json)).await?
|
379 |
-
}
|
380 |
-
Ok(())
|
381 |
-
};
|
382 |
-
|
383 |
-
let span = span!(tracing::Level::TRACE, "lesson_speaker", lesson_id = query.id);
|
384 |
-
let _ = span.enter();
|
385 |
-
let res: anyhow::Result<()> = fut.await;
|
386 |
-
match res {
|
387 |
-
Ok(()) => {
|
388 |
-
tracing::info!("lesson speaker closed");
|
389 |
-
}
|
390 |
-
Err(e) => {
|
391 |
-
tracing::warn!("lesson speaker error: {}", e);
|
392 |
-
}
|
393 |
-
}
|
394 |
-
})
|
395 |
-
}
|
396 |
-
|
397 |
-
fn u8_to_i16(input: &[u8]) -> Vec<i16> {
|
398 |
-
input
|
399 |
-
.chunks_exact(2)
|
400 |
-
.map(|chunk| {
|
401 |
-
let mut buf = [0u8; 2];
|
402 |
-
buf.copy_from_slice(chunk);
|
403 |
-
i16::from_le_bytes(buf)
|
404 |
-
})
|
405 |
-
.collect::<Vec<i16>>()
|
406 |
-
}
|
407 |
-
|
408 |
-
#[cfg(test)]
|
409 |
-
mod test {
|
410 |
-
use std::time::Duration;
|
411 |
-
use async_stream::stream;
|
412 |
-
use poem::listener::{Acceptor, Listener};
|
413 |
-
use tokio::pin;
|
414 |
-
use tokio::time::sleep;
|
415 |
-
use tokio_stream::StreamExt;
|
416 |
-
use tokio_tungstenite::{
|
417 |
-
connect_async,
|
418 |
-
tungstenite::Message,
|
419 |
-
};
|
420 |
-
use tracing::{info, error};
|
421 |
-
use crate::asr::slice_i16_to_u8;
|
422 |
-
use super::*;
|
423 |
-
|
424 |
-
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
425 |
-
#[tracing_test::traced_test]
|
426 |
-
async fn test_single() {
|
427 |
-
let shared_config = aws_config::load_defaults(BehaviorVersion::latest()).await;
|
428 |
-
let ctx = Context {
|
429 |
-
lessons_manager: LessonsManager::new(&shared_config),
|
430 |
-
};
|
431 |
-
|
432 |
-
let query = SingleQuery {
|
433 |
-
id: 1,
|
434 |
-
from: "zh-CN".to_string(),
|
435 |
-
to: "en-US".to_string(),
|
436 |
-
voice: None,
|
437 |
-
};
|
438 |
-
|
439 |
-
let acceptor = TcpListener::bind("[::]:0")
|
440 |
-
.into_acceptor()
|
441 |
-
.await
|
442 |
-
.unwrap();
|
443 |
-
let addr = acceptor
|
444 |
-
.local_addr()
|
445 |
-
.remove(0)
|
446 |
-
.as_socket_addr()
|
447 |
-
.cloned()
|
448 |
-
.unwrap();
|
449 |
-
let server = Server::new_with_acceptor(acceptor);
|
450 |
-
let handle = tokio::spawn(async move {
|
451 |
-
let _ = server.run(
|
452 |
-
Route::new()
|
453 |
-
.at("/ws/voice", get(stream_single))
|
454 |
-
.data(ctx)
|
455 |
-
).await;
|
456 |
-
});
|
457 |
-
|
458 |
-
let url = format!(
|
459 |
-
"ws://{}/ws/voice?id={}&from={}&to={}&voice={}",
|
460 |
-
addr, query.id, query.from, query.to, query.voice.unwrap_or("Amy".to_string())
|
461 |
-
);
|
462 |
-
let (mut client_stream, _) = connect_async(url)
|
463 |
-
.await
|
464 |
-
.unwrap();
|
465 |
-
|
466 |
-
client_stream
|
467 |
-
.send(Message::Binary(Vec::new()))
|
468 |
-
.await
|
469 |
-
.unwrap();
|
470 |
-
|
471 |
-
|
472 |
-
let wav = hound::WavReader::open("whisper/samples/samples_jfk.wav")
|
473 |
-
.expect("failed to open wav");
|
474 |
-
let spec = wav.spec();
|
475 |
-
info!("{:?}", spec);
|
476 |
-
let samples = wav
|
477 |
-
.into_samples::<i16>()
|
478 |
-
.map(|s| s.unwrap())
|
479 |
-
.collect::<Vec<i16>>();
|
480 |
-
let chunks = samples.chunks(1600)
|
481 |
-
.map(|chunk| chunk.to_vec())
|
482 |
-
.into_iter();
|
483 |
-
|
484 |
-
let audio_stream = stream! {
|
485 |
-
for chunk in chunks {
|
486 |
-
yield slice_i16_to_u8(&chunk);
|
487 |
-
sleep(Duration::from_millis(10)).await;
|
488 |
-
}
|
489 |
-
};
|
490 |
-
pin!(audio_stream);
|
491 |
-
|
492 |
-
let recv_fut = async {
|
493 |
-
while let Some(voice_slice) = audio_stream.next().await {
|
494 |
-
client_stream.send(Message::Binary(voice_slice)).await?;
|
495 |
-
}
|
496 |
-
info!("sent all voice chunks");
|
497 |
-
|
498 |
-
while let Some(next_msg) = client_stream.next().await {
|
499 |
-
debug!(?next_msg);
|
500 |
-
let msg = next_msg?;
|
501 |
-
let Message::Text(json_str) = msg else { continue };
|
502 |
-
let Ok(evt) = serde_json::from_str::<SingleEvent>(&json_str) else { continue };
|
503 |
-
if let SingleEvent::Voice { .. } = evt {
|
504 |
-
return Ok(())
|
505 |
-
}
|
506 |
-
}
|
507 |
-
|
508 |
-
Ok(()) as anyhow::Result<()>
|
509 |
-
};
|
510 |
-
|
511 |
-
select! {
|
512 |
-
res = recv_fut => {
|
513 |
-
if let Err(e) = res {
|
514 |
-
error!("Error: {:?}", e);
|
515 |
-
assert!(false, "Error: {}", e);
|
516 |
-
}
|
517 |
-
}
|
518 |
-
_ = sleep(Duration::from_secs(10)) => {
|
519 |
-
assert!(false, "timeout");
|
520 |
-
}
|
521 |
-
};
|
522 |
-
|
523 |
-
handle.abort();
|
524 |
-
}
|
525 |
-
}
|
|
|
3 |
* SPDX-License-Identifier: Apache-2.0.
|
4 |
*/
|
5 |
|
|
|
|
|
|
|
|
|
|
|
6 |
use aws_config::BehaviorVersion;
|
7 |
use aws_sdk_transcribestreaming::meta::PKG_VERSION;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
use tokio::select;
|
9 |
+
use tracing::{debug};
|
10 |
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
11 |
|
12 |
+
use polyhedron::app;
|
|
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
#[tokio::main]
|
16 |
async fn main() -> Result<(), std::io::Error> {
|
|
|
22 |
debug!("Transcribe client version: {}", PKG_VERSION);
|
23 |
|
24 |
let shared_config = aws_config::load_defaults(BehaviorVersion::latest()).await;
|
|
|
|
|
|
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
select! {
|
28 |
+
res = app(&shared_config) => res,
|
29 |
_ = tokio::signal::ctrl_c() => {
|
30 |
tracing::info!("Shutting down");
|
31 |
Ok(())
|
32 |
},
|
33 |
}
|
34 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/stream_single.rs
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
use aws_config::BehaviorVersion;
|
3 |
+
use std::time::Duration;
|
4 |
+
use async_stream::stream;
|
5 |
+
use poem::{
|
6 |
+
get,
|
7 |
+
listener::{Listener, Acceptor, TcpListener},
|
8 |
+
EndpointExt, Route, Server,
|
9 |
+
};
|
10 |
+
use tokio::{pin, select};
|
11 |
+
use tokio::time::sleep;
|
12 |
+
use tokio_stream::StreamExt;
|
13 |
+
use tokio_tungstenite::{
|
14 |
+
connect_async,
|
15 |
+
tungstenite::Message,
|
16 |
+
};
|
17 |
+
use futures_util::sink::SinkExt;
|
18 |
+
use tracing::{info, error, debug};
|
19 |
+
use polyhedron::{
|
20 |
+
stream_single,
|
21 |
+
SingleEvent,
|
22 |
+
asr::slice_i16_to_u8,
|
23 |
+
Context
|
24 |
+
};
|
25 |
+
|
26 |
+
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
27 |
+
#[tracing_test::traced_test]
|
28 |
+
async fn test_single() {
|
29 |
+
let shared_config = aws_config::load_defaults(BehaviorVersion::latest()).await;
|
30 |
+
let ctx = Context::new(&shared_config);
|
31 |
+
|
32 |
+
let acceptor = TcpListener::bind("[::]:0")
|
33 |
+
.into_acceptor()
|
34 |
+
.await
|
35 |
+
.unwrap();
|
36 |
+
let addr = acceptor
|
37 |
+
.local_addr()
|
38 |
+
.remove(0)
|
39 |
+
.as_socket_addr()
|
40 |
+
.cloned()
|
41 |
+
.unwrap();
|
42 |
+
let server = Server::new_with_acceptor(acceptor);
|
43 |
+
let handle = tokio::spawn(async move {
|
44 |
+
let _ = server.run(
|
45 |
+
Route::new()
|
46 |
+
.at("/ws/voice", get(stream_single))
|
47 |
+
.data(ctx)
|
48 |
+
).await;
|
49 |
+
});
|
50 |
+
|
51 |
+
let url = format!(
|
52 |
+
"ws://{}/ws/voice?id=123abc&from=zh-CN&to=en-US&voice=Amy", addr
|
53 |
+
);
|
54 |
+
let (mut client_stream, _) = connect_async(url)
|
55 |
+
.await
|
56 |
+
.unwrap();
|
57 |
+
|
58 |
+
client_stream
|
59 |
+
.send(Message::Binary(Vec::new()))
|
60 |
+
.await
|
61 |
+
.unwrap();
|
62 |
+
|
63 |
+
|
64 |
+
let wav = hound::WavReader::open("whisper/samples/samples_jfk.wav")
|
65 |
+
.expect("failed to open wav");
|
66 |
+
let spec = wav.spec();
|
67 |
+
info!("{:?}", spec);
|
68 |
+
let samples = wav
|
69 |
+
.into_samples::<i16>()
|
70 |
+
.map(|s| s.unwrap())
|
71 |
+
.collect::<Vec<i16>>();
|
72 |
+
let chunks = samples.chunks(1600)
|
73 |
+
.map(|chunk| chunk.to_vec())
|
74 |
+
.into_iter();
|
75 |
+
|
76 |
+
let audio_stream = stream! {
|
77 |
+
for chunk in chunks {
|
78 |
+
yield slice_i16_to_u8(&chunk);
|
79 |
+
sleep(Duration::from_millis(10)).await;
|
80 |
+
}
|
81 |
+
};
|
82 |
+
pin!(audio_stream);
|
83 |
+
|
84 |
+
let recv_fut = async {
|
85 |
+
while let Some(voice_slice) = audio_stream.next().await {
|
86 |
+
client_stream.send(Message::Binary(voice_slice)).await?;
|
87 |
+
}
|
88 |
+
info!("sent all voice chunks");
|
89 |
+
|
90 |
+
while let Some(Ok(msg)) = client_stream.next().await {
|
91 |
+
debug!("recv: {:?}", msg);
|
92 |
+
let Message::Text(json_str) = msg else { continue };
|
93 |
+
let Ok(evt) = serde_json::from_str::<SingleEvent>(&json_str) else { continue };
|
94 |
+
if let SingleEvent::Voice { .. } = evt {
|
95 |
+
return Ok(())
|
96 |
+
}
|
97 |
+
}
|
98 |
+
|
99 |
+
Ok(()) as anyhow::Result<()>
|
100 |
+
};
|
101 |
+
|
102 |
+
select! {
|
103 |
+
res = recv_fut => {
|
104 |
+
if let Err(e) = res {
|
105 |
+
error!("Error: {:?}", e);
|
106 |
+
assert!(false, "Error: {}", e);
|
107 |
+
}
|
108 |
+
}
|
109 |
+
_ = sleep(Duration::from_secs(10)) => {
|
110 |
+
assert!(false, "timeout");
|
111 |
+
}
|
112 |
+
};
|
113 |
+
|
114 |
+
handle.abort();
|
115 |
+
}
|
whisper/src/handler.rs
CHANGED
@@ -445,7 +445,6 @@ mod test {
|
|
445 |
|
446 |
let mut rx = whisper_handler.subscribe();
|
447 |
let send_fut = async {
|
448 |
-
// tokio::time::sleep(Duration::from_secs(5)).await;
|
449 |
for chunk in chunks {
|
450 |
let _ = whisper_handler.send_i16(chunk).await.expect("failed to send sample");
|
451 |
tokio::time::sleep(Duration::from_millis(100)).await;
|
|
|
445 |
|
446 |
let mut rx = whisper_handler.subscribe();
|
447 |
let send_fut = async {
|
|
|
448 |
for chunk in chunks {
|
449 |
let _ = whisper_handler.send_i16(chunk).await.expect("failed to send sample");
|
450 |
tokio::time::sleep(Duration::from_millis(100)).await;
|