Spaces:
Sleeping
Sleeping
WHISPER!
Browse files
- config.yaml +1 -1
- src/config.rs +7 -6
- src/main.rs +4 -4
- src/whisper.rs +117 -89
config.yaml
CHANGED
@@ -5,7 +5,7 @@ whisper:
|
|
5 |
# n_threads: 4
|
6 |
length_ms: 5000
|
7 |
keep_ms: 200
|
8 |
-
step_ms:
|
9 |
max_tokens: 0
|
10 |
audio_ctx: 0
|
11 |
vad_thold: 0.6
|
|
|
5 |
# n_threads: 4
|
6 |
length_ms: 5000
|
7 |
keep_ms: 200
|
8 |
+
step_ms: 5000
|
9 |
max_tokens: 0
|
10 |
audio_ctx: 0
|
11 |
vad_thold: 0.6
|
src/config.rs
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
use std::ffi::c_int;
|
2 |
-
use std::fs;
|
3 |
use std::net::IpAddr;
|
4 |
use lazy_static::lazy_static;
|
5 |
use serde::{Deserialize};
|
@@ -87,9 +86,11 @@ pub struct Config {
|
|
87 |
pub(crate) server: Server,
|
88 |
}
|
89 |
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
|
|
|
|
95 |
}
|
|
|
1 |
use std::ffi::c_int;
|
|
|
2 |
use std::net::IpAddr;
|
3 |
use lazy_static::lazy_static;
|
4 |
use serde::{Deserialize};
|
|
|
86 |
pub(crate) server: Server,
|
87 |
}
|
88 |
|
89 |
+
// Compiled only for test builds so the module does not exist in regular builds.
#[cfg(test)]
mod tests {
    /// Smoke test: reads `config.yaml` from the current working directory,
    /// deserializes it into `crate::config::Config`, and prints the result.
    ///
    /// Panics (failing the test) if the file cannot be read or parsed.
    #[tokio::test]
    async fn load() {
        let config_str = tokio::fs::read_to_string("config.yaml")
            .await
            .expect("failed to read config file");
        let params: crate::config::Config =
            serde_yaml::from_str(&config_str).expect("failed to parse config file");
        println!("{:?}", params);
    }
}
|
src/main.rs
CHANGED
@@ -115,10 +115,10 @@ async fn stream_speaker(ctx: Data<&Context>, query: Query<LessonSpeakerQuery>, w
|
|
115 |
match msg.as_ref() {
|
116 |
Some(Ok(Message::Binary(bin))) => {
|
117 |
let _ = whisper.send(bin.to_vec()).await; // whisper test
|
118 |
-
if let Err(e) = origin_tx.send(bin.to_vec()).await {
|
119 |
-
|
120 |
-
|
121 |
-
}
|
122 |
},
|
123 |
Some(Ok(_)) => {
|
124 |
tracing::warn!("Other: {:?}", msg);
|
|
|
115 |
match msg.as_ref() {
|
116 |
Some(Ok(Message::Binary(bin))) => {
|
117 |
let _ = whisper.send(bin.to_vec()).await; // whisper test
|
118 |
+
// if let Err(e) = origin_tx.send(bin.to_vec()).await {
|
119 |
+
// tracing::warn!("failed to send voice: {}", e);
|
120 |
+
// break;
|
121 |
+
// }
|
122 |
},
|
123 |
Some(Ok(_)) => {
|
124 |
tracing::warn!("Other: {:?}", msg);
|
src/whisper.rs
CHANGED
@@ -5,7 +5,7 @@ use std::thread::sleep;
|
|
5 |
use std::time::Duration;
|
6 |
use lazy_static::lazy_static;
|
7 |
use tokio::sync::{broadcast, mpsc, oneshot};
|
8 |
-
use whisper_rs::{convert_integer_to_float_audio, WhisperState, WhisperContext
|
9 |
use whisper_rs_sys::WHISPER_SAMPLE_RATE;
|
10 |
use crate::config::{WhisperParams, CONFIG};
|
11 |
use crate::group::GroupedWithin;
|
@@ -79,28 +79,20 @@ pub struct WhisperHandler {
|
|
79 |
|
80 |
impl WhisperHandler {
|
81 |
pub(crate) fn new(config: WhisperParams) -> Result<Self, Error> {
|
82 |
-
let n_samples_step: usize = (config.step_ms * WHISPER_SAMPLE_RATE / 1000) as usize;
|
83 |
-
let n_samples_len: usize = (config.length_ms * WHISPER_SAMPLE_RATE / 1000) as usize;
|
84 |
-
let n_samples_keep: usize = (config.keep_ms * WHISPER_SAMPLE_RATE / 1000) as usize;
|
85 |
-
let n_new_line: usize = 1.max(config.length_ms / config.step_ms - 1) as usize;
|
86 |
-
let mut n_iter: usize = 0;
|
87 |
-
|
88 |
let (stop_handle, mut stop_signal) = oneshot::channel();
|
89 |
let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<u8>>(128);
|
90 |
-
let mut grouped = GroupedWithin::new(
|
91 |
-
n_samples_step * 2,
|
92 |
-
Duration::from_millis(config.step_ms as u64),
|
93 |
-
pcm_rx,
|
94 |
-
u16::MAX as usize
|
95 |
-
);
|
96 |
let (transcription_tx, _) = broadcast::channel::<Vec<Segment>>(128);
|
97 |
let shared_transcription_tx = transcription_tx.clone();
|
98 |
-
let
|
99 |
.map_err(|e| Error::whisper_error("failed to create WhisperState", e))?;
|
100 |
tokio::task::spawn_blocking(move || {
|
101 |
-
let mut
|
102 |
-
let mut
|
103 |
-
|
|
|
|
|
|
|
|
|
104 |
while let Err(oneshot::error::TryRecvError::Empty) = stop_signal.try_recv() {
|
105 |
let new_pcm_f32 = match grouped.next() {
|
106 |
Err(mpsc::error::TryRecvError::Disconnected) => break,
|
@@ -113,16 +105,9 @@ impl WhisperHandler {
|
|
113 |
}
|
114 |
};
|
115 |
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
let _ = pcm_f32.drain(0..(pcm_f32.len() - n_samples_keep - n_samples_len)).len();
|
120 |
-
}
|
121 |
-
pcm_f32.make_contiguous();
|
122 |
-
let (data, _) = pcm_f32.as_slices();
|
123 |
-
let segments = match inference(&mut state, params, data, 0) {
|
124 |
-
Ok((offset, result)) => {
|
125 |
-
last_offset = offset;
|
126 |
if result.is_empty() {
|
127 |
continue
|
128 |
}
|
@@ -134,19 +119,18 @@ impl WhisperHandler {
|
|
134 |
}
|
135 |
};
|
136 |
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
if let Err(e) = new_line(&mut pcm_f32, n_samples_keep, &config, &mut state, segments.len(), &mut prompt_tokens) {
|
141 |
-
tracing::warn!("failed to new_line: {}", e);
|
142 |
}
|
143 |
-
|
144 |
-
|
145 |
-
if let Err(e) = shared_transcription_tx.send(segments) {
|
146 |
-
tracing::error!("failed to send transcription: {}", e);
|
147 |
-
break
|
148 |
-
};
|
149 |
}
|
|
|
|
|
|
|
|
|
|
|
150 |
}
|
151 |
});
|
152 |
Ok(Self {
|
@@ -165,66 +149,110 @@ impl WhisperHandler {
|
|
165 |
}
|
166 |
}
|
167 |
|
168 |
-
|
169 |
-
state:
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
.
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
|
|
195 |
}
|
196 |
|
197 |
-
|
198 |
-
|
|
|
|
|
|
|
|
|
199 |
|
200 |
-
fn
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
|
207 |
-
|
208 |
-
|
209 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
}
|
211 |
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
let
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
}
|
225 |
}
|
|
|
226 |
}
|
227 |
-
Ok(())
|
228 |
}
|
229 |
|
230 |
impl Drop for WhisperHandler {
|
|
|
5 |
use std::time::Duration;
|
6 |
use lazy_static::lazy_static;
|
7 |
use tokio::sync::{broadcast, mpsc, oneshot};
|
8 |
+
use whisper_rs::{convert_integer_to_float_audio, WhisperState, WhisperContext};
|
9 |
use whisper_rs_sys::WHISPER_SAMPLE_RATE;
|
10 |
use crate::config::{WhisperParams, CONFIG};
|
11 |
use crate::group::GroupedWithin;
|
|
|
79 |
|
80 |
impl WhisperHandler {
|
81 |
pub(crate) fn new(config: WhisperParams) -> Result<Self, Error> {
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
let (stop_handle, mut stop_signal) = oneshot::channel();
|
83 |
let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<u8>>(128);
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
let (transcription_tx, _) = broadcast::channel::<Vec<Segment>>(128);
|
85 |
let shared_transcription_tx = transcription_tx.clone();
|
86 |
+
let state = WHISPER_CONTEXT.create_state()
|
87 |
.map_err(|e| Error::whisper_error("failed to create WhisperState", e))?;
|
88 |
tokio::task::spawn_blocking(move || {
|
89 |
+
let mut detector = Detector::new(state, &CONFIG.whisper);
|
90 |
+
let mut grouped = GroupedWithin::new(
|
91 |
+
detector.n_samples_step * 2,
|
92 |
+
Duration::from_millis(config.step_ms as u64),
|
93 |
+
pcm_rx,
|
94 |
+
u16::MAX as usize
|
95 |
+
);
|
96 |
while let Err(oneshot::error::TryRecvError::Empty) = stop_signal.try_recv() {
|
97 |
let new_pcm_f32 = match grouped.next() {
|
98 |
Err(mpsc::error::TryRecvError::Disconnected) => break,
|
|
|
105 |
}
|
106 |
};
|
107 |
|
108 |
+
detector.feed(new_pcm_f32);
|
109 |
+
let segments = match detector.inference() {
|
110 |
+
Ok(result) => {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
if result.is_empty() {
|
112 |
continue
|
113 |
}
|
|
|
119 |
}
|
120 |
};
|
121 |
|
122 |
+
if tracing::enabled!(tracing::Level::TRACE) {
|
123 |
+
for segment in segments.iter() {
|
124 |
+
tracing::trace!("[{}] SEGMENT: {}", detector.n_iter, segment.text);
|
|
|
|
|
125 |
}
|
126 |
+
} else if tracing::enabled!(tracing::Level::DEBUG) {
|
127 |
+
tracing::debug!("[{}] SEGMENT: {}", detector.n_iter, segments[0].text);
|
|
|
|
|
|
|
|
|
128 |
}
|
129 |
+
|
130 |
+
if let Err(e) = shared_transcription_tx.send(segments) {
|
131 |
+
tracing::error!("failed to send transcription: {}", e);
|
132 |
+
break
|
133 |
+
};
|
134 |
}
|
135 |
});
|
136 |
Ok(Self {
|
|
|
149 |
}
|
150 |
}
|
151 |
|
152 |
+
struct Detector {
|
153 |
+
state: WhisperState<'static>,
|
154 |
+
config: &'static WhisperParams,
|
155 |
+
n_samples_keep: usize,
|
156 |
+
n_samples_step: usize,
|
157 |
+
n_samples_len: usize,
|
158 |
+
n_new_line: usize,
|
159 |
+
n_iter: usize,
|
160 |
+
prompt_tokens: Vec<c_int>,
|
161 |
+
pcm_f32: VecDeque<f32>,
|
162 |
+
offset: usize,
|
163 |
+
}
|
164 |
+
|
165 |
+
impl Detector {
|
166 |
+
fn new(state: WhisperState<'static>,
|
167 |
+
config: &'static WhisperParams) -> Self {
|
168 |
+
Detector {
|
169 |
+
state,
|
170 |
+
config,
|
171 |
+
n_samples_keep: (config.keep_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
|
172 |
+
n_samples_step: (config.step_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
|
173 |
+
n_samples_len: (config.length_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
|
174 |
+
n_new_line: 1.max(config.length_ms / config.step_ms - 1) as usize,
|
175 |
+
n_iter: 0,
|
176 |
+
prompt_tokens: Default::default(),
|
177 |
+
pcm_f32: VecDeque::from(vec![0f32; 30 * WHISPER_SAMPLE_RATE as usize]),
|
178 |
+
offset: 0,
|
179 |
+
}
|
180 |
}
|
181 |
|
182 |
+
fn feed(&mut self, new_pcm_f32: Vec<f32>) {
|
183 |
+
self.pcm_f32.extend(new_pcm_f32);
|
184 |
+
if self.pcm_f32.len() > self.n_samples_len + self.n_samples_keep {
|
185 |
+
let _ = self.pcm_f32.drain(0..(self.pcm_f32.len() - self.n_samples_keep - self.n_samples_len)).len();
|
186 |
+
}
|
187 |
+
}
|
188 |
|
189 |
+
fn inference(&mut self) -> Result<Vec<Segment>, Error> {
|
190 |
+
let params = self.config.to_full_params(self.prompt_tokens.as_slice());
|
191 |
+
let start = std::time::Instant::now();
|
192 |
+
let _ = self.state.full(params, self.pcm_f32.make_contiguous())
|
193 |
+
.map_err(|e| Error::whisper_error("failed to initialize WhisperState", e))?;
|
194 |
+
let end = std::time::Instant::now();
|
195 |
+
if end - start > Duration::from_millis(self.config.step_ms as u64) {
|
196 |
+
tracing::warn!("full() took {} ms too slow", (end - start).as_millis());
|
197 |
+
}
|
198 |
+
|
199 |
+
let num_segments = self.state
|
200 |
+
.full_n_segments()
|
201 |
+
.map_err(|e| Error::whisper_error("failed to get number of segments", e))?;
|
202 |
+
let mut segments: Vec<Segment> = Vec::with_capacity(num_segments as usize);
|
203 |
+
for i in 0..num_segments {
|
204 |
+
let segment = self.state
|
205 |
+
.full_get_segment_text(i)
|
206 |
+
.map_err(|e| Error::whisper_error("failed to get segment", e))?;
|
207 |
+
let timestamp_offset: i64 = (self.offset * 1000 / WHISPER_SAMPLE_RATE as usize) as i64;
|
208 |
+
let start_timestamp: i64 = timestamp_offset + 10 * self.state
|
209 |
+
.full_get_segment_t0(i)
|
210 |
+
.map_err(|e| Error::whisper_error("failed to get start timestamp", e))?;
|
211 |
+
let end_timestamp: i64 = timestamp_offset + 10 * self.state
|
212 |
+
.full_get_segment_t1(i)
|
213 |
+
.map_err(|e| Error::whisper_error("failed to get end timestamp", e))?;
|
214 |
+
// tracing::trace!("{}", segment);
|
215 |
+
segments.push(Segment { start_timestamp, end_timestamp, text: segment });
|
216 |
+
}
|
217 |
|
218 |
+
|
219 |
+
self.n_iter = self.n_iter + 1;
|
220 |
+
|
221 |
+
if self.n_iter % self.n_new_line == 0 {
|
222 |
+
self.next_line()?;
|
223 |
+
Ok(segments)
|
224 |
+
} else {
|
225 |
+
Ok(vec![])
|
226 |
+
}
|
227 |
}
|
228 |
|
229 |
+
fn next_line(&mut self) -> Result<(), Error> {
|
230 |
+
|
231 |
+
// keep the last n_samples_keep samples from pcm_f32
|
232 |
+
if self.pcm_f32.len() > self.n_samples_keep {
|
233 |
+
let _ = self.pcm_f32.drain(0..(self.pcm_f32.len() - self.n_samples_keep)).len();
|
234 |
+
}
|
235 |
+
|
236 |
+
if !self.config.no_context {
|
237 |
+
self.prompt_tokens.clear();
|
238 |
+
|
239 |
+
let num_segments = self.state
|
240 |
+
.full_n_segments()
|
241 |
+
.map_err(|e| Error::whisper_error("failed to get number of segments", e))?;
|
242 |
+
for i in 0..num_segments {
|
243 |
+
let token_count = self.state
|
244 |
+
.full_n_tokens(i)
|
245 |
+
.map_err(|e| Error::whisper_error("failed to get number of tokens", e))?;
|
246 |
+
for j in 0..token_count {
|
247 |
+
let token = self.state
|
248 |
+
.full_get_token_id(i, j)
|
249 |
+
.map_err(|e| Error::whisper_error("failed to get token", e))?;
|
250 |
+
self.prompt_tokens.push(token);
|
251 |
+
}
|
252 |
}
|
253 |
}
|
254 |
+
Ok(())
|
255 |
}
|
|
|
256 |
}
|
257 |
|
258 |
impl Drop for WhisperHandler {
|