Spaces:
Sleeping
Sleeping
2nd commit
Browse files- src/main.rs +162 -79
src/main.rs
CHANGED
@@ -5,24 +5,30 @@
|
|
5 |
|
6 |
#![allow(clippy::result_large_err)]
|
7 |
|
|
|
8 |
use std::default::Default;
|
|
|
|
|
|
|
9 |
use tokio::sync::mpsc::channel;
|
10 |
use async_stream::stream;
|
11 |
use aws_config::meta::region::RegionProviderChain;
|
12 |
use aws_sdk_transcribestreaming::primitives::Blob;
|
13 |
use aws_sdk_transcribestreaming::types::{AudioStream, AudioEvent, LanguageCode, MediaEncoding, TranscriptResultStream};
|
14 |
-
use aws_sdk_transcribestreaming::{config::Region, meta::PKG_VERSION
|
15 |
-
use
|
16 |
use clap::Parser;
|
17 |
|
18 |
use poem::{handler, listener::TcpListener, Server, get, Route, IntoResponse, Endpoint, EndpointExt};
|
19 |
-
use futures_util::{Sink, SinkExt};
|
20 |
use poem::endpoint::StaticFilesEndpoint;
|
21 |
use poem::web::websocket::{Message, WebSocket};
|
22 |
use futures_util::stream::StreamExt;
|
|
|
23 |
|
24 |
-
use tokio::select;
|
25 |
use tokio::sync::mpsc::{Receiver, Sender};
|
|
|
26 |
|
27 |
|
28 |
#[derive(Debug, Parser)]
|
@@ -46,7 +52,6 @@ enum ReplyEvent {
|
|
46 |
Synthesized(Vec<u8>),
|
47 |
}
|
48 |
|
49 |
-
const CHUNK_SIZE: usize = 8192;
|
50 |
|
51 |
/// Transcribes an audio file to text.
|
52 |
/// # Arguments
|
@@ -58,35 +63,11 @@ const CHUNK_SIZE: usize = 8192;
|
|
58 |
/// If not supplied, uses the value of the **AWS_REGION** environment variable.
|
59 |
/// If the environment variable is not set, defaults to **us-west-2**.
|
60 |
/// * `[-v]` - Whether to display additional information.
|
61 |
-
async fn stream_process(
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
// audio_file,
|
67 |
-
verbose,
|
68 |
-
} = Opt::parse();
|
69 |
-
|
70 |
-
let region_provider = RegionProviderChain::first_try(region.map(Region::new))
|
71 |
-
.or_default_provider()
|
72 |
-
.or_else(Region::new("us-west-2"));
|
73 |
-
|
74 |
-
println!();
|
75 |
-
|
76 |
-
if verbose {
|
77 |
-
println!("Transcribe client version: {}", PKG_VERSION);
|
78 |
-
println!(
|
79 |
-
"Region: {}",
|
80 |
-
region_provider.region().await.unwrap().as_ref()
|
81 |
-
);
|
82 |
-
// println!("Audio filename: {}", &audio_file);
|
83 |
-
println!();
|
84 |
-
}
|
85 |
-
|
86 |
-
let shared_config = aws_config::from_env().region(region_provider).load().await;
|
87 |
-
let transcript_client = Client::new(&shared_config);
|
88 |
-
let translate_client = aws_sdk_translate::Client::new(&shared_config);
|
89 |
-
let polly_client = aws_sdk_polly::Client::new(&shared_config);
|
90 |
|
91 |
let input_stream = stream! {
|
92 |
while let Some(raw) = rx.recv().await {
|
@@ -94,47 +75,24 @@ async fn stream_process(mut rx: Receiver<Vec<u8>>, tx: Sender<ReplyEvent>) -> Re
|
|
94 |
}
|
95 |
};
|
96 |
|
97 |
-
let
|
98 |
.start_stream_transcription()
|
99 |
.language_code(LanguageCode::ZhCn)//LanguageCode::EnGb
|
100 |
.media_sample_rate_hertz(16000)
|
101 |
.media_encoding(MediaEncoding::Pcm)
|
102 |
.audio_stream(input_stream.into())
|
103 |
.send()
|
104 |
-
.await
|
105 |
-
|
106 |
-
while let Some(event) = output.transcript_result_stream.recv().await? {
|
107 |
-
match event {
|
108 |
-
TranscriptResultStream::TranscriptEvent(transcript_event) => {
|
109 |
-
let transcript = transcript_event.transcript.expect("transcript");
|
110 |
-
for result in transcript.results.unwrap_or_default() {
|
111 |
-
if result.is_partial {
|
112 |
-
if verbose {
|
113 |
-
println!("Partial: {:?}", result);
|
114 |
-
}
|
115 |
-
} else {
|
116 |
-
let first_alternative = &result.alternatives.as_ref().expect("should have")[0];
|
117 |
-
let slice = first_alternative.transcript.as_ref().expect("should have");
|
118 |
-
println!("Line: {:?}", slice);
|
119 |
-
tx.send(ReplyEvent::Transcribed(slice.clone())).await.expect("failed to send");
|
120 |
-
let lc = result.language_code.as_ref().map(|lc| lc.as_str().to_string());
|
121 |
-
let translated = translate(&translate_client, first_alternative.transcript.clone(), lc).await;
|
122 |
-
if let Some(has) = translated {
|
123 |
-
tx.send(ReplyEvent::Transcribed(has.clone())).await.expect("failed to send");
|
124 |
-
println!("Translated: {}", has);
|
125 |
-
if let Some(synthesized) = synthesize(&polly_client, has).await {
|
126 |
-
tx.send(ReplyEvent::Synthesized(synthesized)).await.expect("failed to send");
|
127 |
-
}
|
128 |
-
}
|
129 |
|
130 |
-
|
131 |
-
}
|
132 |
-
}
|
133 |
-
otherwise => panic!("received unexpected event type: {:?}", otherwise),
|
134 |
-
}
|
135 |
-
}
|
136 |
|
137 |
-
|
|
|
|
|
|
|
|
|
|
|
138 |
}
|
139 |
|
140 |
async fn translate(client: &aws_sdk_translate::Client, transcript: Option<String>, source_lang_code: Option<String>) -> Option<String> {
|
@@ -158,24 +116,63 @@ async fn synthesize(client: &aws_sdk_polly::Client, transcript: String) -> Optio
|
|
158 |
Some(bs.to_vec())
|
159 |
}
|
160 |
|
161 |
-
|
162 |
-
|
163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
-
let mut pcm: Vec<u8> = Vec::new();
|
166 |
-
for sample in samples_result.unwrap() {
|
167 |
-
pcm.put_i16_le(sample);
|
168 |
-
}
|
169 |
-
pcm
|
170 |
}
|
171 |
|
172 |
#[tokio::main]
|
173 |
-
async fn main() -> Result<(), std::io::Error> {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
"/",
|
175 |
StaticFilesEndpoint::new("./static")
|
176 |
.show_files_listing()
|
177 |
.index_file("index.html"),
|
178 |
-
).at("/translate", get(stream_translate))
|
|
|
179 |
let listener = TcpListener::bind("[::]:8080");
|
180 |
let server = Server::new(listener);
|
181 |
|
@@ -184,11 +181,19 @@ async fn main() -> Result<(), std::io::Error> { let app = Route::new().nest(
|
|
184 |
|
185 |
|
186 |
#[handler]
|
187 |
-
async fn stream_translate(ws: WebSocket) -> impl IntoResponse {
|
|
|
|
|
|
|
188 |
ws.on_upgrade(|mut socket| async move {
|
189 |
let (origin_tx, origin_rx) = channel::<Vec<u8>>(128);
|
190 |
let (translate_tx, mut translate_rx) = channel::<ReplyEvent>(128);
|
191 |
-
let stream_fut = stream_process(
|
|
|
|
|
|
|
|
|
|
|
192 |
|
193 |
let ws_fut = async {
|
194 |
loop {
|
@@ -239,3 +244,81 @@ async fn stream_translate(ws: WebSocket) -> impl IntoResponse {
|
|
239 |
})
|
240 |
}
|
241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
#![allow(clippy::result_large_err)]
|
7 |
|
8 |
+
use std::collections::BTreeMap;
|
9 |
use std::default::Default;
|
10 |
+
use std::error::Error;
|
11 |
+
use std::fmt::{Debug, Display, Formatter};
|
12 |
+
use std::future::Future;
|
13 |
use tokio::sync::mpsc::channel;
|
14 |
use async_stream::stream;
|
15 |
use aws_config::meta::region::RegionProviderChain;
|
16 |
use aws_sdk_transcribestreaming::primitives::Blob;
|
17 |
use aws_sdk_transcribestreaming::types::{AudioStream, AudioEvent, LanguageCode, MediaEncoding, TranscriptResultStream};
|
18 |
+
use aws_sdk_transcribestreaming::{config::Region, meta::PKG_VERSION};
|
19 |
+
use aws_sdk_transcribestreaming::operation::start_stream_transcription::StartStreamTranscriptionOutput;
|
20 |
use clap::Parser;
|
21 |
|
22 |
use poem::{handler, listener::TcpListener, Server, get, Route, IntoResponse, Endpoint, EndpointExt};
|
23 |
+
use futures_util::{Sink, SinkExt, TryFutureExt, TryStreamExt};
|
24 |
use poem::endpoint::StaticFilesEndpoint;
|
25 |
use poem::web::websocket::{Message, WebSocket};
|
26 |
use futures_util::stream::StreamExt;
|
27 |
+
use poem::web::Data;
|
28 |
|
29 |
+
use tokio::{select};
|
30 |
use tokio::sync::mpsc::{Receiver, Sender};
|
31 |
+
use tokio_stream::Stream;
|
32 |
|
33 |
|
34 |
#[derive(Debug, Parser)]
|
|
|
52 |
Synthesized(Vec<u8>),
|
53 |
}
|
54 |
|
|
|
55 |
|
56 |
/// Transcribes an audio file to text.
|
57 |
/// # Arguments
|
|
|
63 |
/// If not supplied, uses the value of the **AWS_REGION** environment variable.
|
64 |
/// If the environment variable is not set, defaults to **us-west-2**.
|
65 |
/// * `[-v]` - Whether to display additional information.
|
66 |
+
async fn stream_process(translate_client: aws_sdk_translate::Client,
|
67 |
+
polly_client: aws_sdk_polly::Client,
|
68 |
+
transcript_client: aws_sdk_transcribestreaming::Client,
|
69 |
+
mut rx: Receiver<Vec<u8>>,
|
70 |
+
tx: Sender<ReplyEvent>) -> Result<(), StreamTranscriptionError> {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
let input_stream = stream! {
|
73 |
while let Some(raw) = rx.recv().await {
|
|
|
75 |
}
|
76 |
};
|
77 |
|
78 |
+
let output = transcript_client
|
79 |
.start_stream_transcription()
|
80 |
.language_code(LanguageCode::ZhCn)//LanguageCode::EnGb
|
81 |
.media_sample_rate_hertz(16000)
|
82 |
.media_encoding(MediaEncoding::Pcm)
|
83 |
.audio_stream(input_stream.into())
|
84 |
.send()
|
85 |
+
.await
|
86 |
+
.map_err(|e| StreamTranscriptionError::EstablishStreamError(Box::new(e)))?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
+
let output_stream = to_stream(output);
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
+
output_stream
|
91 |
+
.flat_map(|res| {
|
92 |
+
process(translate_client.clone(), polly_client.clone(), res)
|
93 |
+
})
|
94 |
+
.try_for_each(|reply| tx.send(reply).map_err(|e| StreamTranscriptionError::Shutdown))
|
95 |
+
.await
|
96 |
}
|
97 |
|
98 |
async fn translate(client: &aws_sdk_translate::Client, transcript: Option<String>, source_lang_code: Option<String>) -> Option<String> {
|
|
|
116 |
Some(bs.to_vec())
|
117 |
}
|
118 |
|
119 |
+
#[derive(Clone)]
|
120 |
+
struct Context {
|
121 |
+
translate_client: aws_sdk_translate::Client,
|
122 |
+
polly_client: aws_sdk_polly::Client,
|
123 |
+
transcript_client: aws_sdk_transcribestreaming::Client,
|
124 |
+
}
|
125 |
+
|
126 |
+
struct Lessons {
|
127 |
+
lessons: BTreeMap<u32, Lesson>
|
128 |
+
}
|
129 |
+
|
130 |
+
struct Lesson {
|
131 |
|
|
|
|
|
|
|
|
|
|
|
132 |
}
|
133 |
|
134 |
#[tokio::main]
|
135 |
+
async fn main() -> Result<(), std::io::Error> {
|
136 |
+
tracing_subscriber::fmt::init();
|
137 |
+
|
138 |
+
let Opt {
|
139 |
+
region,
|
140 |
+
verbose,
|
141 |
+
} = Opt::parse();
|
142 |
+
|
143 |
+
let region_provider = RegionProviderChain::first_try(region.map(Region::new))
|
144 |
+
.or_default_provider()
|
145 |
+
.or_else(Region::new("us-west-2"));
|
146 |
+
|
147 |
+
println!();
|
148 |
+
|
149 |
+
if verbose {
|
150 |
+
println!("Transcribe client version: {}", PKG_VERSION);
|
151 |
+
println!(
|
152 |
+
"Region: {}",
|
153 |
+
region_provider.region().await.unwrap().as_ref()
|
154 |
+
);
|
155 |
+
// println!("Audio filename: {}", &audio_file);
|
156 |
+
println!();
|
157 |
+
}
|
158 |
+
|
159 |
+
let shared_config = aws_config::from_env().region(region_provider).load().await;
|
160 |
+
let transcript_client = aws_sdk_transcribestreaming::Client::new(&shared_config);
|
161 |
+
let translate_client = aws_sdk_translate::Client::new(&shared_config);
|
162 |
+
let polly_client = aws_sdk_polly::Client::new(&shared_config);
|
163 |
+
let ctx = Context {
|
164 |
+
translate_client,
|
165 |
+
polly_client,
|
166 |
+
transcript_client,
|
167 |
+
};
|
168 |
+
|
169 |
+
let app = Route::new().nest(
|
170 |
"/",
|
171 |
StaticFilesEndpoint::new("./static")
|
172 |
.show_files_listing()
|
173 |
.index_file("index.html"),
|
174 |
+
).at("/translate", get(stream_translate))
|
175 |
+
.data(ctx);
|
176 |
let listener = TcpListener::bind("[::]:8080");
|
177 |
let server = Server::new(listener);
|
178 |
|
|
|
181 |
|
182 |
|
183 |
#[handler]
|
184 |
+
async fn stream_translate(ctx: Data<&Context>, ws: WebSocket) -> impl IntoResponse {
|
185 |
+
let translate_client = ctx.translate_client.clone();
|
186 |
+
let polly_client = ctx.polly_client.clone();
|
187 |
+
let transcript_client = ctx.transcript_client.clone();
|
188 |
ws.on_upgrade(|mut socket| async move {
|
189 |
let (origin_tx, origin_rx) = channel::<Vec<u8>>(128);
|
190 |
let (translate_tx, mut translate_rx) = channel::<ReplyEvent>(128);
|
191 |
+
let stream_fut = stream_process(
|
192 |
+
translate_client,
|
193 |
+
polly_client,
|
194 |
+
transcript_client,
|
195 |
+
origin_rx,
|
196 |
+
translate_tx);
|
197 |
|
198 |
let ws_fut = async {
|
199 |
loop {
|
|
|
244 |
})
|
245 |
}
|
246 |
|
247 |
+
#[derive(Debug)]
|
248 |
+
enum StreamTranscriptionError {
|
249 |
+
EstablishStreamError(Box<dyn Error + Send + Sync>),
|
250 |
+
TranscriptResultStreamError(Box<dyn Error + Send + Sync>),
|
251 |
+
Shutdown,
|
252 |
+
Unknown
|
253 |
+
}
|
254 |
+
|
255 |
+
|
256 |
+
impl Display for StreamTranscriptionError {
|
257 |
+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
258 |
+
match self {
|
259 |
+
StreamTranscriptionError::EstablishStreamError(e) => write!(f, "EstablishStreamError: {}", e),
|
260 |
+
StreamTranscriptionError::TranscriptResultStreamError(e) => write!(f, "TranscriptResultStreamError: {}", e),
|
261 |
+
StreamTranscriptionError::Shutdown => write!(f, "Shutdown"),
|
262 |
+
StreamTranscriptionError::Unknown => write!(f, "Unknown"),
|
263 |
+
}
|
264 |
+
}
|
265 |
+
}
|
266 |
+
|
267 |
+
impl Error for StreamTranscriptionError {
|
268 |
+
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
269 |
+
match self {
|
270 |
+
StreamTranscriptionError::EstablishStreamError(e) => Some(e.as_ref()),
|
271 |
+
StreamTranscriptionError::TranscriptResultStreamError(e) => Some(e.as_ref()),
|
272 |
+
StreamTranscriptionError::Shutdown => None,
|
273 |
+
StreamTranscriptionError::Unknown => None,
|
274 |
+
}
|
275 |
+
}
|
276 |
+
}
|
277 |
+
|
278 |
+
fn to_stream(mut output: StartStreamTranscriptionOutput) -> impl Stream<Item=Result<String, StreamTranscriptionError>> {
|
279 |
+
stream! {
|
280 |
+
while let Some(event) = output
|
281 |
+
.transcript_result_stream
|
282 |
+
.recv()
|
283 |
+
.await
|
284 |
+
.map_err(|e| StreamTranscriptionError::TranscriptResultStreamError(Box::new(e)))? {
|
285 |
+
match event {
|
286 |
+
TranscriptResultStream::TranscriptEvent(transcript_event) => {
|
287 |
+
let transcript = transcript_event.transcript.expect("transcript");
|
288 |
+
for result in transcript.results.unwrap_or_default() {
|
289 |
+
if !result.is_partial {
|
290 |
+
let first_alternative = &result.alternatives.as_ref().expect("should have")[0];
|
291 |
+
let slice = first_alternative.transcript.as_ref().expect("should have");
|
292 |
+
yield Ok(slice.clone());
|
293 |
+
}
|
294 |
+
}
|
295 |
+
}
|
296 |
+
otherwise => yield Err(StreamTranscriptionError::Unknown),
|
297 |
+
}
|
298 |
+
}
|
299 |
+
}
|
300 |
+
}
|
301 |
+
|
302 |
+
fn process(translate_client: aws_sdk_translate::Client,
|
303 |
+
polly_client: aws_sdk_polly::Client,
|
304 |
+
res: Result<String, StreamTranscriptionError>) -> impl Stream<Item=Result<ReplyEvent, StreamTranscriptionError>> {
|
305 |
+
stream! {
|
306 |
+
match res {
|
307 |
+
Ok(transcription) => {
|
308 |
+
yield Ok(ReplyEvent::Transcribed(transcription.clone()));
|
309 |
+
let translated = translate(&translate_client, Some(transcription), Some("en".to_string())).await;
|
310 |
+
if let Some(has) = translated {
|
311 |
+
yield Ok(ReplyEvent::Translated(has.clone()));
|
312 |
+
println!("Translated: {}", has);
|
313 |
+
if let Some(synthesized) = synthesize(&polly_client, has).await {
|
314 |
+
yield Ok(ReplyEvent::Synthesized(synthesized));
|
315 |
+
}
|
316 |
+
}
|
317 |
+
},
|
318 |
+
Err(e) => {
|
319 |
+
yield Err(e);
|
320 |
+
}
|
321 |
+
}
|
322 |
+
|
323 |
+
}
|
324 |
+
}
|