mingyang91 committed on
Commit
5b9ecd0
·
verified ·
1 Parent(s): 55b9ff7
Files changed (4) hide show
  1. config.yaml +1 -1
  2. src/config.rs +7 -6
  3. src/main.rs +4 -4
  4. src/whisper.rs +117 -89
config.yaml CHANGED
@@ -5,7 +5,7 @@ whisper:
5
  # n_threads: 4
6
  length_ms: 5000
7
  keep_ms: 200
8
- step_ms: 500
9
  max_tokens: 0
10
  audio_ctx: 0
11
  vad_thold: 0.6
 
5
  # n_threads: 4
6
  length_ms: 5000
7
  keep_ms: 200
8
+ step_ms: 5000
9
  max_tokens: 0
10
  audio_ctx: 0
11
  vad_thold: 0.6
src/config.rs CHANGED
@@ -1,5 +1,4 @@
1
  use std::ffi::c_int;
2
- use std::fs;
3
  use std::net::IpAddr;
4
  use lazy_static::lazy_static;
5
  use serde::{Deserialize};
@@ -87,9 +86,11 @@ pub struct Config {
87
  pub(crate) server: Server,
88
  }
89
 
90
- #[tokio::test]
91
- async fn load() {
92
- let config_str = fs::read_to_string("config.yaml").expect("failed to read config file");
93
- let params: Config = serde_yaml::from_str(config_str.as_str()).expect("failed to parse config file");
94
- println!("{:?}", params);
 
 
95
  }
 
1
  use std::ffi::c_int;
 
2
  use std::net::IpAddr;
3
  use lazy_static::lazy_static;
4
  use serde::{Deserialize};
 
86
  pub(crate) server: Server,
87
  }
88
 
89
+ mod tests {
90
+ #[tokio::test]
91
+ async fn load() {
92
+ let config_str = tokio::fs::read_to_string("config.yaml").await.expect("failed to read config file");
93
+ let params: crate::config::Config = serde_yaml::from_str(config_str.as_str()).expect("failed to parse config file");
94
+ println!("{:?}", params);
95
+ }
96
  }
src/main.rs CHANGED
@@ -115,10 +115,10 @@ async fn stream_speaker(ctx: Data<&Context>, query: Query<LessonSpeakerQuery>, w
115
  match msg.as_ref() {
116
  Some(Ok(Message::Binary(bin))) => {
117
  let _ = whisper.send(bin.to_vec()).await; // whisper test
118
- if let Err(e) = origin_tx.send(bin.to_vec()).await {
119
- tracing::warn!("failed to send voice: {}", e);
120
- break;
121
- }
122
  },
123
  Some(Ok(_)) => {
124
  tracing::warn!("Other: {:?}", msg);
 
115
  match msg.as_ref() {
116
  Some(Ok(Message::Binary(bin))) => {
117
  let _ = whisper.send(bin.to_vec()).await; // whisper test
118
+ // if let Err(e) = origin_tx.send(bin.to_vec()).await {
119
+ // tracing::warn!("failed to send voice: {}", e);
120
+ // break;
121
+ // }
122
  },
123
  Some(Ok(_)) => {
124
  tracing::warn!("Other: {:?}", msg);
src/whisper.rs CHANGED
@@ -5,7 +5,7 @@ use std::thread::sleep;
5
  use std::time::Duration;
6
  use lazy_static::lazy_static;
7
  use tokio::sync::{broadcast, mpsc, oneshot};
8
- use whisper_rs::{convert_integer_to_float_audio, WhisperState, WhisperContext, FullParams};
9
  use whisper_rs_sys::WHISPER_SAMPLE_RATE;
10
  use crate::config::{WhisperParams, CONFIG};
11
  use crate::group::GroupedWithin;
@@ -79,28 +79,20 @@ pub struct WhisperHandler {
79
 
80
  impl WhisperHandler {
81
  pub(crate) fn new(config: WhisperParams) -> Result<Self, Error> {
82
- let n_samples_step: usize = (config.step_ms * WHISPER_SAMPLE_RATE / 1000) as usize;
83
- let n_samples_len: usize = (config.length_ms * WHISPER_SAMPLE_RATE / 1000) as usize;
84
- let n_samples_keep: usize = (config.keep_ms * WHISPER_SAMPLE_RATE / 1000) as usize;
85
- let n_new_line: usize = 1.max(config.length_ms / config.step_ms - 1) as usize;
86
- let mut n_iter: usize = 0;
87
-
88
  let (stop_handle, mut stop_signal) = oneshot::channel();
89
  let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<u8>>(128);
90
- let mut grouped = GroupedWithin::new(
91
- n_samples_step * 2,
92
- Duration::from_millis(config.step_ms as u64),
93
- pcm_rx,
94
- u16::MAX as usize
95
- );
96
  let (transcription_tx, _) = broadcast::channel::<Vec<Segment>>(128);
97
  let shared_transcription_tx = transcription_tx.clone();
98
- let mut state = WHISPER_CONTEXT.create_state()
99
  .map_err(|e| Error::whisper_error("failed to create WhisperState", e))?;
100
  tokio::task::spawn_blocking(move || {
101
- let mut prompt_tokens: Vec<c_int> = Default::default();
102
- let mut pcm_f32: VecDeque<f32> = VecDeque::from(vec![0f32; 30 * WHISPER_SAMPLE_RATE as usize]);
103
- let mut last_offset: usize = 0;
 
 
 
 
104
  while let Err(oneshot::error::TryRecvError::Empty) = stop_signal.try_recv() {
105
  let new_pcm_f32 = match grouped.next() {
106
  Err(mpsc::error::TryRecvError::Disconnected) => break,
@@ -113,16 +105,9 @@ impl WhisperHandler {
113
  }
114
  };
115
 
116
- let params = config.to_full_params(prompt_tokens.as_slice());
117
- pcm_f32.extend(new_pcm_f32);
118
- if pcm_f32.len() > n_samples_len + n_samples_keep {
119
- let _ = pcm_f32.drain(0..(pcm_f32.len() - n_samples_keep - n_samples_len)).len();
120
- }
121
- pcm_f32.make_contiguous();
122
- let (data, _) = pcm_f32.as_slices();
123
- let segments = match inference(&mut state, params, data, 0) {
124
- Ok((offset, result)) => {
125
- last_offset = offset;
126
  if result.is_empty() {
127
  continue
128
  }
@@ -134,19 +119,18 @@ impl WhisperHandler {
134
  }
135
  };
136
 
137
- n_iter = n_iter + 1;
138
-
139
- if n_iter % n_new_line == 0 {
140
- if let Err(e) = new_line(&mut pcm_f32, n_samples_keep, &config, &mut state, segments.len(), &mut prompt_tokens) {
141
- tracing::warn!("failed to new_line: {}", e);
142
  }
143
- tracing::debug!("LINE: {}", segments.first().unwrap().text);
144
-
145
- if let Err(e) = shared_transcription_tx.send(segments) {
146
- tracing::error!("failed to send transcription: {}", e);
147
- break
148
- };
149
  }
 
 
 
 
 
150
  }
151
  });
152
  Ok(Self {
@@ -165,66 +149,110 @@ impl WhisperHandler {
165
  }
166
  }
167
 
168
- fn inference(
169
- state: &mut WhisperState,
170
- params: FullParams,
171
- pcm_f32: &[f32],
172
- mut offset: usize,
173
- ) -> Result<(usize, Vec<Segment>), Error> {
174
-
175
- let _ = state.full(params, pcm_f32)
176
- .map_err(|e| Error::whisper_error("failed to initialize WhisperState", e))?;
177
-
178
- let num_segments = state
179
- .full_n_segments()
180
- .map_err(|e| Error::whisper_error("failed to get number of segments", e))?;
181
- let mut segments: Vec<Segment> = Vec::with_capacity(num_segments as usize);
182
- for i in 0..num_segments {
183
- let segment = state
184
- .full_get_segment_text(i)
185
- .map_err(|e| Error::whisper_error("failed to get segment", e))?;
186
- let timestamp_offset: i64 = (offset * 1000 / WHISPER_SAMPLE_RATE as usize) as i64;
187
- let start_timestamp: i64 = timestamp_offset + 10 * state
188
- .full_get_segment_t0(i)
189
- .map_err(|e| Error::whisper_error("failed to get start timestamp", e))?;
190
- let end_timestamp: i64 = timestamp_offset + 10 * state
191
- .full_get_segment_t1(i)
192
- .map_err(|e| Error::whisper_error("failed to get end timestamp", e))?;
193
- // tracing::trace!("{}", segment);
194
- segments.push(Segment { start_timestamp, end_timestamp, text: segment });
 
195
  }
196
 
197
- Ok((offset, segments))
198
- }
 
 
 
 
199
 
200
- fn new_line(pcm_f32: &mut VecDeque<f32>,
201
- n_samples_keep: usize,
202
- config: &WhisperParams,
203
- state: &mut WhisperState,
204
- num_segments: usize,
205
- prompt_tokens: &mut Vec<c_int>) -> Result<(), Error> {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
- // keep the last n_samples_keep samples from pcm_f32
208
- if pcm_f32.len() > n_samples_keep {
209
- let _ = pcm_f32.drain(0..(pcm_f32.len() - n_samples_keep)).len();
 
 
 
 
 
 
210
  }
211
 
212
- if !config.no_context {
213
- prompt_tokens.clear();
214
-
215
- for i in 0..num_segments as c_int {
216
- let token_count = state
217
- .full_n_tokens(i)
218
- .map_err(|e| Error::whisper_error("failed to get number of tokens", e))?;
219
- for j in 0..token_count {
220
- let token = state
221
- .full_get_token_id(i, j)
222
- .map_err(|e| Error::whisper_error("failed to get token", e))?;
223
- prompt_tokens.push(token);
 
 
 
 
 
 
 
 
 
 
 
224
  }
225
  }
 
226
  }
227
- Ok(())
228
  }
229
 
230
  impl Drop for WhisperHandler {
 
5
  use std::time::Duration;
6
  use lazy_static::lazy_static;
7
  use tokio::sync::{broadcast, mpsc, oneshot};
8
+ use whisper_rs::{convert_integer_to_float_audio, WhisperState, WhisperContext};
9
  use whisper_rs_sys::WHISPER_SAMPLE_RATE;
10
  use crate::config::{WhisperParams, CONFIG};
11
  use crate::group::GroupedWithin;
 
79
 
80
  impl WhisperHandler {
81
  pub(crate) fn new(config: WhisperParams) -> Result<Self, Error> {
 
 
 
 
 
 
82
  let (stop_handle, mut stop_signal) = oneshot::channel();
83
  let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<u8>>(128);
 
 
 
 
 
 
84
  let (transcription_tx, _) = broadcast::channel::<Vec<Segment>>(128);
85
  let shared_transcription_tx = transcription_tx.clone();
86
+ let state = WHISPER_CONTEXT.create_state()
87
  .map_err(|e| Error::whisper_error("failed to create WhisperState", e))?;
88
  tokio::task::spawn_blocking(move || {
89
+ let mut detector = Detector::new(state, &CONFIG.whisper);
90
+ let mut grouped = GroupedWithin::new(
91
+ detector.n_samples_step * 2,
92
+ Duration::from_millis(config.step_ms as u64),
93
+ pcm_rx,
94
+ u16::MAX as usize
95
+ );
96
  while let Err(oneshot::error::TryRecvError::Empty) = stop_signal.try_recv() {
97
  let new_pcm_f32 = match grouped.next() {
98
  Err(mpsc::error::TryRecvError::Disconnected) => break,
 
105
  }
106
  };
107
 
108
+ detector.feed(new_pcm_f32);
109
+ let segments = match detector.inference() {
110
+ Ok(result) => {
 
 
 
 
 
 
 
111
  if result.is_empty() {
112
  continue
113
  }
 
119
  }
120
  };
121
 
122
+ if tracing::enabled!(tracing::Level::TRACE) {
123
+ for segment in segments.iter() {
124
+ tracing::trace!("[{}] SEGMENT: {}", detector.n_iter, segment.text);
 
 
125
  }
126
+ } else if tracing::enabled!(tracing::Level::DEBUG) {
127
+ tracing::debug!("[{}] SEGMENT: {}", detector.n_iter, segments[0].text);
 
 
 
 
128
  }
129
+
130
+ if let Err(e) = shared_transcription_tx.send(segments) {
131
+ tracing::error!("failed to send transcription: {}", e);
132
+ break
133
+ };
134
  }
135
  });
136
  Ok(Self {
 
149
  }
150
  }
151
 
152
+ struct Detector {
153
+ state: WhisperState<'static>,
154
+ config: &'static WhisperParams,
155
+ n_samples_keep: usize,
156
+ n_samples_step: usize,
157
+ n_samples_len: usize,
158
+ n_new_line: usize,
159
+ n_iter: usize,
160
+ prompt_tokens: Vec<c_int>,
161
+ pcm_f32: VecDeque<f32>,
162
+ offset: usize,
163
+ }
164
+
165
+ impl Detector {
166
+ fn new(state: WhisperState<'static>,
167
+ config: &'static WhisperParams) -> Self {
168
+ Detector {
169
+ state,
170
+ config,
171
+ n_samples_keep: (config.keep_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
172
+ n_samples_step: (config.step_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
173
+ n_samples_len: (config.length_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
174
+ n_new_line: 1.max(config.length_ms / config.step_ms - 1) as usize,
175
+ n_iter: 0,
176
+ prompt_tokens: Default::default(),
177
+ pcm_f32: VecDeque::from(vec![0f32; 30 * WHISPER_SAMPLE_RATE as usize]),
178
+ offset: 0,
179
+ }
180
  }
181
 
182
+ fn feed(&mut self, new_pcm_f32: Vec<f32>) {
183
+ self.pcm_f32.extend(new_pcm_f32);
184
+ if self.pcm_f32.len() > self.n_samples_len + self.n_samples_keep {
185
+ let _ = self.pcm_f32.drain(0..(self.pcm_f32.len() - self.n_samples_keep - self.n_samples_len)).len();
186
+ }
187
+ }
188
 
189
+ fn inference(&mut self) -> Result<Vec<Segment>, Error> {
190
+ let params = self.config.to_full_params(self.prompt_tokens.as_slice());
191
+ let start = std::time::Instant::now();
192
+ let _ = self.state.full(params, self.pcm_f32.make_contiguous())
193
+ .map_err(|e| Error::whisper_error("failed to initialize WhisperState", e))?;
194
+ let end = std::time::Instant::now();
195
+ if end - start > Duration::from_millis(self.config.step_ms as u64) {
196
+ tracing::warn!("full() took {} ms too slow", (end - start).as_millis());
197
+ }
198
+
199
+ let num_segments = self.state
200
+ .full_n_segments()
201
+ .map_err(|e| Error::whisper_error("failed to get number of segments", e))?;
202
+ let mut segments: Vec<Segment> = Vec::with_capacity(num_segments as usize);
203
+ for i in 0..num_segments {
204
+ let segment = self.state
205
+ .full_get_segment_text(i)
206
+ .map_err(|e| Error::whisper_error("failed to get segment", e))?;
207
+ let timestamp_offset: i64 = (self.offset * 1000 / WHISPER_SAMPLE_RATE as usize) as i64;
208
+ let start_timestamp: i64 = timestamp_offset + 10 * self.state
209
+ .full_get_segment_t0(i)
210
+ .map_err(|e| Error::whisper_error("failed to get start timestamp", e))?;
211
+ let end_timestamp: i64 = timestamp_offset + 10 * self.state
212
+ .full_get_segment_t1(i)
213
+ .map_err(|e| Error::whisper_error("failed to get end timestamp", e))?;
214
+ // tracing::trace!("{}", segment);
215
+ segments.push(Segment { start_timestamp, end_timestamp, text: segment });
216
+ }
217
 
218
+
219
+ self.n_iter = self.n_iter + 1;
220
+
221
+ if self.n_iter % self.n_new_line == 0 {
222
+ self.next_line()?;
223
+ Ok(segments)
224
+ } else {
225
+ Ok(vec![])
226
+ }
227
  }
228
 
229
+ fn next_line(&mut self) -> Result<(), Error> {
230
+
231
+ // keep the last n_samples_keep samples from pcm_f32
232
+ if self.pcm_f32.len() > self.n_samples_keep {
233
+ let _ = self.pcm_f32.drain(0..(self.pcm_f32.len() - self.n_samples_keep)).len();
234
+ }
235
+
236
+ if !self.config.no_context {
237
+ self.prompt_tokens.clear();
238
+
239
+ let num_segments = self.state
240
+ .full_n_segments()
241
+ .map_err(|e| Error::whisper_error("failed to get number of segments", e))?;
242
+ for i in 0..num_segments {
243
+ let token_count = self.state
244
+ .full_n_tokens(i)
245
+ .map_err(|e| Error::whisper_error("failed to get number of tokens", e))?;
246
+ for j in 0..token_count {
247
+ let token = self.state
248
+ .full_get_token_id(i, j)
249
+ .map_err(|e| Error::whisper_error("failed to get token", e))?;
250
+ self.prompt_tokens.push(token);
251
+ }
252
  }
253
  }
254
+ Ok(())
255
  }
 
256
  }
257
 
258
  impl Drop for WhisperHandler {