mingyang91 committed on
Commit
f33df9c
·
verified ·
1 Parent(s): fb5758b

clean code

Browse files
Files changed (3) hide show
  1. src/asr/whisper.rs +18 -31
  2. src/main.rs +11 -2
  3. whisper/src/handler.rs +46 -31
src/asr/whisper.rs CHANGED
@@ -1,8 +1,7 @@
1
  use std::fmt::{Debug, Formatter};
2
  use async_trait::async_trait;
3
- use tokio::{select, spawn};
4
  use tokio::sync::broadcast::Receiver;
5
- use tokio::sync::broadcast::error::RecvError;
6
  use lazy_static::lazy_static;
7
 
8
  extern crate whisper;
@@ -29,39 +28,27 @@ impl Debug for WhisperAsr {
29
 
30
  impl WhisperAsr {
31
  pub async fn from_config() -> Result<WhisperAsr, Error> {
32
- let whisper = CONTEXT.create_handler(&SETTINGS.whisper, "".to_string())?;
33
  let mut output_rx = whisper.subscribe();
34
  let (tx, _) = tokio::sync::broadcast::channel(64);
35
  let shared_tx = tx.clone();
36
  let fut = async move {
37
- loop {
38
- select! {
39
- poll = output_rx.recv() => {
40
- match poll {
41
- Ok(outputs) => {
42
- for output in outputs {
43
- let res = match output {
44
- Output::Stable(segment) => tx.send(Event {
45
- transcript: segment.text,
46
- is_final: true,
47
- }),
48
- Output::Unstable(segment) => tx.send(Event {
49
- transcript: segment.text,
50
- is_final: false,
51
- }),
52
- };
53
- if let Err(e) = res {
54
- tracing::warn!("Failed to send whisper event: {}", e);
55
- break
56
- }
57
- }
58
- },
59
- Err(RecvError::Closed) => break,
60
- Err(RecvError::Lagged(lagged)) => {
61
- tracing::warn!("Whisper ASR output lagged: {}", lagged);
62
- }
63
- }
64
- },
65
  }
66
  }
67
  };
 
1
  use std::fmt::{Debug, Formatter};
2
  use async_trait::async_trait;
3
+ use tokio::{spawn};
4
  use tokio::sync::broadcast::Receiver;
 
5
  use lazy_static::lazy_static;
6
 
7
  extern crate whisper;
 
28
 
29
  impl WhisperAsr {
30
  pub async fn from_config() -> Result<WhisperAsr, Error> {
31
+ let whisper = CONTEXT.create_handler(SETTINGS.whisper.clone(), "".to_string());
32
  let mut output_rx = whisper.subscribe();
33
  let (tx, _) = tokio::sync::broadcast::channel(64);
34
  let shared_tx = tx.clone();
35
  let fut = async move {
36
+ while let Ok(outputs) = output_rx.recv().await {
37
+ for output in outputs {
38
+ let evt = match output {
39
+ Output::Stable(segment) => Event {
40
+ transcript: segment.text,
41
+ is_final: true,
42
+ },
43
+ Output::Unstable(segment) => Event {
44
+ transcript: segment.text,
45
+ is_final: false,
46
+ },
47
+ };
48
+ if let Err(e) = tx.send(evt) {
49
+ tracing::warn!("Failed to send whisper event: {}", e);
50
+ break
51
+ }
 
 
 
 
 
 
 
 
 
 
 
 
52
  }
53
  }
54
  };
src/main.rs CHANGED
@@ -24,7 +24,7 @@ use poem::{
24
  };
25
  use serde::{Deserialize, Serialize};
26
  use tokio::select;
27
- use tracing::debug;
28
  use tracing_subscriber::{fmt, prelude::*, EnvFilter};
29
 
30
  use crate::{config::*, lesson::*};
@@ -122,6 +122,9 @@ async fn stream_speaker(
122
  msg = socket.next() => {
123
  let Some(res) = msg else { break };
124
  let msg = res?;
 
 
 
125
  let Message::Binary(bin) = msg else {
126
  tracing::warn!("Other: {:?}", msg);
127
  continue
@@ -131,7 +134,9 @@ async fn stream_speaker(
131
  },
132
  output = transcribe_rx.recv() => {
133
  let evt = output?;
134
- tracing::trace!("Transcribed: {}", evt.transcript);
 
 
135
  let evt = LiveLessonTextEvent::Transcription { content: evt.transcript, is_final: evt.is_final };
136
  let Ok(json) = serde_json::to_string(&evt) else {
137
  tracing::warn!("failed to serialize json: {:?}", evt);
@@ -144,6 +149,8 @@ async fn stream_speaker(
144
  Ok(())
145
  };
146
 
 
 
147
  let res: anyhow::Result<()> = fut.await;
148
  match res {
149
  Ok(()) => {
@@ -252,6 +259,8 @@ async fn stream_listener(
252
  }
253
  };
254
 
 
 
255
  let res: anyhow::Result<()> = fut.await;
256
  match res {
257
  Ok(()) => {
 
24
  };
25
  use serde::{Deserialize, Serialize};
26
  use tokio::select;
27
+ use tracing::{debug, span};
28
  use tracing_subscriber::{fmt, prelude::*, EnvFilter};
29
 
30
  use crate::{config::*, lesson::*};
 
122
  msg = socket.next() => {
123
  let Some(res) = msg else { break };
124
  let msg = res?;
125
+ if msg.is_close() {
126
+ break
127
+ }
128
  let Message::Binary(bin) = msg else {
129
  tracing::warn!("Other: {:?}", msg);
130
  continue
 
134
  },
135
  output = transcribe_rx.recv() => {
136
  let evt = output?;
137
+ if evt.is_final {
138
+ tracing::trace!("Transcribed: {}", evt.transcript);
139
+ }
140
  let evt = LiveLessonTextEvent::Transcription { content: evt.transcript, is_final: evt.is_final };
141
  let Ok(json) = serde_json::to_string(&evt) else {
142
  tracing::warn!("failed to serialize json: {:?}", evt);
 
149
  Ok(())
150
  };
151
 
152
+ let span = span!(tracing::Level::TRACE, "lesson_speaker", lesson_id = query.id);
153
+ let _ = span.enter();
154
  let res: anyhow::Result<()> = fut.await;
155
  match res {
156
  Ok(()) => {
 
259
  }
260
  };
261
 
262
+ let span = span!(tracing::Level::TRACE, "lesson_listener", lesson_id = query.id);
263
+ let _ = span.enter();
264
  let res: anyhow::Result<()> = fut.await;
265
  match res {
266
  Ok(()) => {
whisper/src/handler.rs CHANGED
@@ -3,7 +3,9 @@ use std::{
3
  fmt::{Debug, Display, Formatter},
4
  thread::sleep,
5
  time::Duration,
 
6
  };
 
7
  use fvad::SampleRate;
8
 
9
  use tokio::sync::{broadcast, mpsc, oneshot};
@@ -15,18 +17,27 @@ use crate::{config::WhisperConfig, group::GroupedWithin};
15
 
16
  const WHISPER_SAMPLE_RATE: usize = whisper_rs_sys::WHISPER_SAMPLE_RATE as usize;
17
 
 
18
  pub struct Context {
19
- context: WhisperContext,
20
  }
21
 
22
- impl <'a> Context {
23
  pub fn new(model: &str) -> Result<Context, WhisperError> {
24
  WhisperContext::new(model)
25
- .map(|context| Self { context })
26
  }
27
 
28
- pub fn create_handler(&'static self, config: &'static WhisperConfig, prompt: String) -> Result<WhisperHandler, Error> {
29
- WhisperHandler::new(&self.context, config, prompt)
 
 
 
 
 
 
 
 
30
  }
31
  }
32
 
@@ -87,22 +98,30 @@ pub struct WhisperHandler {
87
  }
88
 
89
  impl WhisperHandler {
90
-
91
- fn new(whisper_context: &'static WhisperContext, config: &'static WhisperConfig, prompt: String) -> Result<Self, Error> {
92
- // let whisper_context = get_whisper_context(config.model.clone()).await;
93
- let state = whisper_context
94
- .create_state()
95
- .map_err(|e| Error::whisper_error("failed to create WhisperState", e))?;
96
- let preset_prompt_tokens = whisper_context
97
- .tokenize(&prompt, config.max_prompt_tokens)
98
- .map_err(|e| Error::whisper_error("failed to tokenize prompt", e))?;
99
  let vad_slice_size = WHISPER_SAMPLE_RATE / 100 * 3;
100
  let (stop_handle, mut stop_signal) = oneshot::channel();
101
  let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<i16>>(128);
102
  let (transcription_tx, _) = broadcast::channel::<Vec<Output>>(128);
103
  let shared_transcription_tx = transcription_tx.clone();
104
 
105
- tokio::task::spawn_blocking(move || {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  let mut vad = fvad::Fvad::new().expect("failed to create VAD")
107
  .set_sample_rate(SampleRate::Rate16kHz);
108
  let mut detector = Detector::new(state, &config, preset_prompt_tokens);
@@ -112,6 +131,7 @@ impl WhisperHandler {
112
  pcm_rx,
113
  u16::MAX as usize,
114
  );
 
115
  while let Err(oneshot::error::TryRecvError::Empty) = stop_signal.try_recv() {
116
  if detector.has_crossed_next_line() {
117
  if let Some(segment) = detector.next_line() {
@@ -137,7 +157,6 @@ impl WhisperHandler {
137
  } else {
138
  vad.is_voice_frame(frame).unwrap_or(true)
139
  }
140
- // true
141
  })
142
  .collect::<Vec<_>>()
143
  .concat();
@@ -168,13 +187,15 @@ impl WhisperHandler {
168
  break;
169
  };
170
  }
171
- });
 
 
172
 
173
- Ok(Self {
174
  tx: pcm_tx,
175
  transcription_tx,
176
  stop_handle: Some(stop_handle),
177
- })
178
  }
179
 
180
  pub fn subscribe(&self) -> broadcast::Receiver<Vec<Output>> {
@@ -187,9 +208,9 @@ impl WhisperHandler {
187
  }
188
 
189
  #[allow(dead_code)]
190
- struct Detector {
191
- state: WhisperState<'static>,
192
- config: &'static WhisperConfig,
193
  start_time: Instant,
194
  segment: Option<Segment>,
195
  line_num: usize,
@@ -202,10 +223,10 @@ struct Detector {
202
  offset: usize,
203
  }
204
 
205
- impl Detector {
206
  fn new(
207
- state: WhisperState<'static>,
208
- config: &'static WhisperConfig,
209
  preset_prompt_tokens: Vec<WhisperToken>,
210
  ) -> Self {
211
  Detector {
@@ -229,12 +250,6 @@ impl Detector {
229
  if self.pcm_f32.len() < self.n_samples_len {
230
  return;
231
  }
232
- // let len_to_drain = self
233
- // .pcm_f32
234
- // .drain(0..(self.pcm_f32.len() - self.n_samples_len))
235
- // .len();
236
- // warn!("ASR too slow, drain {} samples", len_to_drain);
237
- // self.offset += len_to_drain;
238
  }
239
 
240
  fn inference(&mut self) -> Result<Vec<Segment>, Error> {
 
3
  fmt::{Debug, Display, Formatter},
4
  thread::sleep,
5
  time::Duration,
6
+ sync::Arc
7
  };
8
+ use std::ops::Deref;
9
  use fvad::SampleRate;
10
 
11
  use tokio::sync::{broadcast, mpsc, oneshot};
 
17
 
18
  const WHISPER_SAMPLE_RATE: usize = whisper_rs_sys::WHISPER_SAMPLE_RATE as usize;
19
 
20
+ #[derive(Clone)]
21
  pub struct Context {
22
+ inner: Arc<WhisperContext>,
23
  }
24
 
25
+ impl Context {
26
  pub fn new(model: &str) -> Result<Context, WhisperError> {
27
  WhisperContext::new(model)
28
+ .map(|context| Self { inner: Arc::new(context) })
29
  }
30
 
31
+ pub fn create_handler(&self, config: WhisperConfig, prompt: String) -> WhisperHandler {
32
+ WhisperHandler::new(self.clone(), config, prompt)
33
+ }
34
+ }
35
+
36
+ impl Deref for Context {
37
+ type Target = WhisperContext;
38
+
39
+ fn deref(&self) -> &Self::Target {
40
+ &self.inner
41
  }
42
  }
43
 
 
98
  }
99
 
100
  impl WhisperHandler {
101
+ fn new(whisper_context: Context, config: WhisperConfig, prompt: String) -> Self {
 
 
 
 
 
 
 
 
102
  let vad_slice_size = WHISPER_SAMPLE_RATE / 100 * 3;
103
  let (stop_handle, mut stop_signal) = oneshot::channel();
104
  let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<i16>>(128);
105
  let (transcription_tx, _) = broadcast::channel::<Vec<Output>>(128);
106
  let shared_transcription_tx = transcription_tx.clone();
107
 
108
+ let task = move || {
109
+ let state = match whisper_context.create_state() {
110
+ Ok(state) => state,
111
+ Err(e) => {
112
+ tracing::error!("failed to create WhisperState: {}", e);
113
+ return Err(Error::whisper_error("failed to create WhisperState", e))
114
+ }
115
+ };
116
+
117
+ let preset_prompt_tokens = match whisper_context.tokenize(&prompt, config.max_prompt_tokens) {
118
+ Ok(tokens) => tokens,
119
+ Err(e) => {
120
+ tracing::error!("failed to tokenize prompt: {}", e);
121
+ return Err(Error::whisper_error("failed to tokenize prompt", e))
122
+ }
123
+ };
124
+
125
  let mut vad = fvad::Fvad::new().expect("failed to create VAD")
126
  .set_sample_rate(SampleRate::Rate16kHz);
127
  let mut detector = Detector::new(state, &config, preset_prompt_tokens);
 
131
  pcm_rx,
132
  u16::MAX as usize,
133
  );
134
+
135
  while let Err(oneshot::error::TryRecvError::Empty) = stop_signal.try_recv() {
136
  if detector.has_crossed_next_line() {
137
  if let Some(segment) = detector.next_line() {
 
157
  } else {
158
  vad.is_voice_frame(frame).unwrap_or(true)
159
  }
 
160
  })
161
  .collect::<Vec<_>>()
162
  .concat();
 
187
  break;
188
  };
189
  }
190
+ Ok::<(), Error>(())
191
+ };
192
+ tokio::task::spawn_blocking(task);
193
 
194
+ Self {
195
  tx: pcm_tx,
196
  transcription_tx,
197
  stop_handle: Some(stop_handle),
198
+ }
199
  }
200
 
201
  pub fn subscribe(&self) -> broadcast::Receiver<Vec<Output>> {
 
208
  }
209
 
210
  #[allow(dead_code)]
211
+ struct Detector<'a> {
212
+ state: WhisperState<'a>,
213
+ config: &'a WhisperConfig,
214
  start_time: Instant,
215
  segment: Option<Segment>,
216
  line_num: usize,
 
223
  offset: usize,
224
  }
225
 
226
+ impl <'a> Detector<'a> {
227
  fn new(
228
+ state: WhisperState<'a>,
229
+ config: &'a WhisperConfig,
230
  preset_prompt_tokens: Vec<WhisperToken>,
231
  ) -> Self {
232
  Detector {
 
250
  if self.pcm_f32.len() < self.n_samples_len {
251
  return;
252
  }
 
 
 
 
 
 
253
  }
254
 
255
  fn inference(&mut self) -> Result<Vec<Segment>, Error> {