mingyang91 committed on
Commit
4840c8f
·
verified ·
1 Parent(s): a446897

2nd commit

Browse files
Files changed (1) hide show
  1. src/main.rs +162 -79
src/main.rs CHANGED
@@ -5,24 +5,30 @@
5
 
6
  #![allow(clippy::result_large_err)]
7
 
 
8
  use std::default::Default;
 
 
 
9
  use tokio::sync::mpsc::channel;
10
  use async_stream::stream;
11
  use aws_config::meta::region::RegionProviderChain;
12
  use aws_sdk_transcribestreaming::primitives::Blob;
13
  use aws_sdk_transcribestreaming::types::{AudioStream, AudioEvent, LanguageCode, MediaEncoding, TranscriptResultStream};
14
- use aws_sdk_transcribestreaming::{config::Region, meta::PKG_VERSION, Client, Error};
15
- use bytes::BufMut;
16
  use clap::Parser;
17
 
18
  use poem::{handler, listener::TcpListener, Server, get, Route, IntoResponse, Endpoint, EndpointExt};
19
- use futures_util::{Sink, SinkExt};
20
  use poem::endpoint::StaticFilesEndpoint;
21
  use poem::web::websocket::{Message, WebSocket};
22
  use futures_util::stream::StreamExt;
 
23
 
24
- use tokio::select;
25
  use tokio::sync::mpsc::{Receiver, Sender};
 
26
 
27
 
28
  #[derive(Debug, Parser)]
@@ -46,7 +52,6 @@ enum ReplyEvent {
46
  Synthesized(Vec<u8>),
47
  }
48
 
49
- const CHUNK_SIZE: usize = 8192;
50
 
51
  /// Transcribes an audio file to text.
52
  /// # Arguments
@@ -58,35 +63,11 @@ const CHUNK_SIZE: usize = 8192;
58
  /// If not supplied, uses the value of the **AWS_REGION** environment variable.
59
  /// If the environment variable is not set, defaults to **us-west-2**.
60
  /// * `[-v]` - Whether to display additional information.
61
- async fn stream_process(mut rx: Receiver<Vec<u8>>, tx: Sender<ReplyEvent>) -> Result<(), Error> {
62
- tracing_subscriber::fmt::init();
63
-
64
- let Opt {
65
- region,
66
- // audio_file,
67
- verbose,
68
- } = Opt::parse();
69
-
70
- let region_provider = RegionProviderChain::first_try(region.map(Region::new))
71
- .or_default_provider()
72
- .or_else(Region::new("us-west-2"));
73
-
74
- println!();
75
-
76
- if verbose {
77
- println!("Transcribe client version: {}", PKG_VERSION);
78
- println!(
79
- "Region: {}",
80
- region_provider.region().await.unwrap().as_ref()
81
- );
82
- // println!("Audio filename: {}", &audio_file);
83
- println!();
84
- }
85
-
86
- let shared_config = aws_config::from_env().region(region_provider).load().await;
87
- let transcript_client = Client::new(&shared_config);
88
- let translate_client = aws_sdk_translate::Client::new(&shared_config);
89
- let polly_client = aws_sdk_polly::Client::new(&shared_config);
90
 
91
  let input_stream = stream! {
92
  while let Some(raw) = rx.recv().await {
@@ -94,47 +75,24 @@ async fn stream_process(mut rx: Receiver<Vec<u8>>, tx: Sender<ReplyEvent>) -> Re
94
  }
95
  };
96
 
97
- let mut output = transcript_client
98
  .start_stream_transcription()
99
  .language_code(LanguageCode::ZhCn)//LanguageCode::EnGb
100
  .media_sample_rate_hertz(16000)
101
  .media_encoding(MediaEncoding::Pcm)
102
  .audio_stream(input_stream.into())
103
  .send()
104
- .await?;
105
-
106
- while let Some(event) = output.transcript_result_stream.recv().await? {
107
- match event {
108
- TranscriptResultStream::TranscriptEvent(transcript_event) => {
109
- let transcript = transcript_event.transcript.expect("transcript");
110
- for result in transcript.results.unwrap_or_default() {
111
- if result.is_partial {
112
- if verbose {
113
- println!("Partial: {:?}", result);
114
- }
115
- } else {
116
- let first_alternative = &result.alternatives.as_ref().expect("should have")[0];
117
- let slice = first_alternative.transcript.as_ref().expect("should have");
118
- println!("Line: {:?}", slice);
119
- tx.send(ReplyEvent::Transcribed(slice.clone())).await.expect("failed to send");
120
- let lc = result.language_code.as_ref().map(|lc| lc.as_str().to_string());
121
- let translated = translate(&translate_client, first_alternative.transcript.clone(), lc).await;
122
- if let Some(has) = translated {
123
- tx.send(ReplyEvent::Transcribed(has.clone())).await.expect("failed to send");
124
- println!("Translated: {}", has);
125
- if let Some(synthesized) = synthesize(&polly_client, has).await {
126
- tx.send(ReplyEvent::Synthesized(synthesized)).await.expect("failed to send");
127
- }
128
- }
129
 
130
- }
131
- }
132
- }
133
- otherwise => panic!("received unexpected event type: {:?}", otherwise),
134
- }
135
- }
136
 
137
- Ok(())
 
 
 
 
 
138
  }
139
 
140
  async fn translate(client: &aws_sdk_translate::Client, transcript: Option<String>, source_lang_code: Option<String>) -> Option<String> {
@@ -158,24 +116,63 @@ async fn synthesize(client: &aws_sdk_polly::Client, transcript: String) -> Optio
158
  Some(bs.to_vec())
159
  }
160
 
161
- fn pcm_data(audio_file: &str) -> Vec<u8> {
162
- let reader = hound::WavReader::open(audio_file).unwrap();
163
- let samples_result: hound::Result<Vec<i16>> = reader.into_samples::<i16>().collect();
 
 
 
 
 
 
 
 
 
164
 
165
- let mut pcm: Vec<u8> = Vec::new();
166
- for sample in samples_result.unwrap() {
167
- pcm.put_i16_le(sample);
168
- }
169
- pcm
170
  }
171
 
172
  #[tokio::main]
173
- async fn main() -> Result<(), std::io::Error> { let app = Route::new().nest(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  "/",
175
  StaticFilesEndpoint::new("./static")
176
  .show_files_listing()
177
  .index_file("index.html"),
178
- ).at("/translate", get(stream_translate));
 
179
  let listener = TcpListener::bind("[::]:8080");
180
  let server = Server::new(listener);
181
 
@@ -184,11 +181,19 @@ async fn main() -> Result<(), std::io::Error> { let app = Route::new().nest(
184
 
185
 
186
  #[handler]
187
- async fn stream_translate(ws: WebSocket) -> impl IntoResponse {
 
 
 
188
  ws.on_upgrade(|mut socket| async move {
189
  let (origin_tx, origin_rx) = channel::<Vec<u8>>(128);
190
  let (translate_tx, mut translate_rx) = channel::<ReplyEvent>(128);
191
- let stream_fut = stream_process(origin_rx, translate_tx);
 
 
 
 
 
192
 
193
  let ws_fut = async {
194
  loop {
@@ -239,3 +244,81 @@ async fn stream_translate(ws: WebSocket) -> impl IntoResponse {
239
  })
240
  }
241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  #![allow(clippy::result_large_err)]
7
 
8
+ use std::collections::BTreeMap;
9
  use std::default::Default;
10
+ use std::error::Error;
11
+ use std::fmt::{Debug, Display, Formatter};
12
+ use std::future::Future;
13
  use tokio::sync::mpsc::channel;
14
  use async_stream::stream;
15
  use aws_config::meta::region::RegionProviderChain;
16
  use aws_sdk_transcribestreaming::primitives::Blob;
17
  use aws_sdk_transcribestreaming::types::{AudioStream, AudioEvent, LanguageCode, MediaEncoding, TranscriptResultStream};
18
+ use aws_sdk_transcribestreaming::{config::Region, meta::PKG_VERSION};
19
+ use aws_sdk_transcribestreaming::operation::start_stream_transcription::StartStreamTranscriptionOutput;
20
  use clap::Parser;
21
 
22
  use poem::{handler, listener::TcpListener, Server, get, Route, IntoResponse, Endpoint, EndpointExt};
23
+ use futures_util::{Sink, SinkExt, TryFutureExt, TryStreamExt};
24
  use poem::endpoint::StaticFilesEndpoint;
25
  use poem::web::websocket::{Message, WebSocket};
26
  use futures_util::stream::StreamExt;
27
+ use poem::web::Data;
28
 
29
+ use tokio::{select};
30
  use tokio::sync::mpsc::{Receiver, Sender};
31
+ use tokio_stream::Stream;
32
 
33
 
34
  #[derive(Debug, Parser)]
 
52
  Synthesized(Vec<u8>),
53
  }
54
 
 
55
 
56
  /// Transcribes an audio file to text.
57
  /// # Arguments
 
63
  /// If not supplied, uses the value of the **AWS_REGION** environment variable.
64
  /// If the environment variable is not set, defaults to **us-west-2**.
65
  /// * `[-v]` - Whether to display additional information.
66
+ async fn stream_process(translate_client: aws_sdk_translate::Client,
67
+ polly_client: aws_sdk_polly::Client,
68
+ transcript_client: aws_sdk_transcribestreaming::Client,
69
+ mut rx: Receiver<Vec<u8>>,
70
+ tx: Sender<ReplyEvent>) -> Result<(), StreamTranscriptionError> {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  let input_stream = stream! {
73
  while let Some(raw) = rx.recv().await {
 
75
  }
76
  };
77
 
78
+ let output = transcript_client
79
  .start_stream_transcription()
80
  .language_code(LanguageCode::ZhCn)//LanguageCode::EnGb
81
  .media_sample_rate_hertz(16000)
82
  .media_encoding(MediaEncoding::Pcm)
83
  .audio_stream(input_stream.into())
84
  .send()
85
+ .await
86
+ .map_err(|e| StreamTranscriptionError::EstablishStreamError(Box::new(e)))?;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
+ let output_stream = to_stream(output);
 
 
 
 
 
89
 
90
+ output_stream
91
+ .flat_map(|res| {
92
+ process(translate_client.clone(), polly_client.clone(), res)
93
+ })
94
+ .try_for_each(|reply| tx.send(reply).map_err(|e| StreamTranscriptionError::Shutdown))
95
+ .await
96
  }
97
 
98
  async fn translate(client: &aws_sdk_translate::Client, transcript: Option<String>, source_lang_code: Option<String>) -> Option<String> {
 
116
  Some(bs.to_vec())
117
  }
118
 
119
+ #[derive(Clone)]
120
+ struct Context {
121
+ translate_client: aws_sdk_translate::Client,
122
+ polly_client: aws_sdk_polly::Client,
123
+ transcript_client: aws_sdk_transcribestreaming::Client,
124
+ }
125
+
126
/// Collection of lessons keyed by a `u32` identifier, kept in key order.
// NOTE(review): nothing in the visible code constructs or reads this type yet.
struct Lessons {
    /// Lessons indexed by their numeric key.
    lessons: BTreeMap<u32, Lesson>,
}

/// Placeholder for a single lesson; it carries no data so far.
struct Lesson {}
133
 
134
  #[tokio::main]
135
+ async fn main() -> Result<(), std::io::Error> {
136
+ tracing_subscriber::fmt::init();
137
+
138
+ let Opt {
139
+ region,
140
+ verbose,
141
+ } = Opt::parse();
142
+
143
+ let region_provider = RegionProviderChain::first_try(region.map(Region::new))
144
+ .or_default_provider()
145
+ .or_else(Region::new("us-west-2"));
146
+
147
+ println!();
148
+
149
+ if verbose {
150
+ println!("Transcribe client version: {}", PKG_VERSION);
151
+ println!(
152
+ "Region: {}",
153
+ region_provider.region().await.unwrap().as_ref()
154
+ );
155
+ // println!("Audio filename: {}", &audio_file);
156
+ println!();
157
+ }
158
+
159
+ let shared_config = aws_config::from_env().region(region_provider).load().await;
160
+ let transcript_client = aws_sdk_transcribestreaming::Client::new(&shared_config);
161
+ let translate_client = aws_sdk_translate::Client::new(&shared_config);
162
+ let polly_client = aws_sdk_polly::Client::new(&shared_config);
163
+ let ctx = Context {
164
+ translate_client,
165
+ polly_client,
166
+ transcript_client,
167
+ };
168
+
169
+ let app = Route::new().nest(
170
  "/",
171
  StaticFilesEndpoint::new("./static")
172
  .show_files_listing()
173
  .index_file("index.html"),
174
+ ).at("/translate", get(stream_translate))
175
+ .data(ctx);
176
  let listener = TcpListener::bind("[::]:8080");
177
  let server = Server::new(listener);
178
 
 
181
 
182
 
183
  #[handler]
184
+ async fn stream_translate(ctx: Data<&Context>, ws: WebSocket) -> impl IntoResponse {
185
+ let translate_client = ctx.translate_client.clone();
186
+ let polly_client = ctx.polly_client.clone();
187
+ let transcript_client = ctx.transcript_client.clone();
188
  ws.on_upgrade(|mut socket| async move {
189
  let (origin_tx, origin_rx) = channel::<Vec<u8>>(128);
190
  let (translate_tx, mut translate_rx) = channel::<ReplyEvent>(128);
191
+ let stream_fut = stream_process(
192
+ translate_client,
193
+ polly_client,
194
+ transcript_client,
195
+ origin_rx,
196
+ translate_tx);
197
 
198
  let ws_fut = async {
199
  loop {
 
244
  })
245
  }
246
 
247
/// Errors surfaced by the streaming transcription pipeline.
#[derive(Debug)]
enum StreamTranscriptionError {
    /// Starting the transcription stream failed; wraps the underlying error.
    EstablishStreamError(Box<dyn Error + Send + Sync>),
    /// Receiving from the transcript result stream failed; wraps the underlying error.
    TranscriptResultStreamError(Box<dyn Error + Send + Sync>),
    /// A reply could not be delivered to the receiving side.
    Shutdown,
    /// An event of an unexpected kind was received.
    Unknown,
}

impl Display for StreamTranscriptionError {
    /// Writes the variant name, followed by the wrapped error's message for
    /// the two variants that carry one.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EstablishStreamError(inner) => {
                write!(f, "EstablishStreamError: {}", inner)
            }
            Self::TranscriptResultStreamError(inner) => {
                write!(f, "TranscriptResultStreamError: {}", inner)
            }
            Self::Shutdown => f.write_str("Shutdown"),
            Self::Unknown => f.write_str("Unknown"),
        }
    }
}

impl Error for StreamTranscriptionError {
    /// Returns the wrapped error for the variants that carry one, `None` otherwise.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::EstablishStreamError(inner) | Self::TranscriptResultStreamError(inner) => {
                Some(&**inner)
            }
            Self::Shutdown | Self::Unknown => None,
        }
    }
}
277
+
278
+ fn to_stream(mut output: StartStreamTranscriptionOutput) -> impl Stream<Item=Result<String, StreamTranscriptionError>> {
279
+ stream! {
280
+ while let Some(event) = output
281
+ .transcript_result_stream
282
+ .recv()
283
+ .await
284
+ .map_err(|e| StreamTranscriptionError::TranscriptResultStreamError(Box::new(e)))? {
285
+ match event {
286
+ TranscriptResultStream::TranscriptEvent(transcript_event) => {
287
+ let transcript = transcript_event.transcript.expect("transcript");
288
+ for result in transcript.results.unwrap_or_default() {
289
+ if !result.is_partial {
290
+ let first_alternative = &result.alternatives.as_ref().expect("should have")[0];
291
+ let slice = first_alternative.transcript.as_ref().expect("should have");
292
+ yield Ok(slice.clone());
293
+ }
294
+ }
295
+ }
296
+ otherwise => yield Err(StreamTranscriptionError::Unknown),
297
+ }
298
+ }
299
+ }
300
+ }
301
+
302
+ fn process(translate_client: aws_sdk_translate::Client,
303
+ polly_client: aws_sdk_polly::Client,
304
+ res: Result<String, StreamTranscriptionError>) -> impl Stream<Item=Result<ReplyEvent, StreamTranscriptionError>> {
305
+ stream! {
306
+ match res {
307
+ Ok(transcription) => {
308
+ yield Ok(ReplyEvent::Transcribed(transcription.clone()));
309
+ let translated = translate(&translate_client, Some(transcription), Some("en".to_string())).await;
310
+ if let Some(has) = translated {
311
+ yield Ok(ReplyEvent::Translated(has.clone()));
312
+ println!("Translated: {}", has);
313
+ if let Some(synthesized) = synthesize(&polly_client, has).await {
314
+ yield Ok(ReplyEvent::Synthesized(synthesized));
315
+ }
316
+ }
317
+ },
318
+ Err(e) => {
319
+ yield Err(e);
320
+ }
321
+ }
322
+
323
+ }
324
+ }