WHISPER!
- config.yaml +4 -4
- src/config.rs +5 -4
- src/group.rs +9 -4
- src/main.rs +40 -33
- src/whisper.rs +66 -40
config.yaml
CHANGED
@@ -3,10 +3,10 @@ server:
   host: ::1
 whisper:
   # n_threads: 4
-
-
-
-  max_tokens:
+  length_ms: 10000
+  keep_ms: 200
+  step_ms: 1000
+  max_tokens: 0
   audio_ctx: 0
   vad_thold: 0.6
   freq_thold: 100.0
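Note: the four new keys define the sliding transcription window used by src/whisper.rs below. As a back-of-the-envelope aside (not part of the diff), here is how those millisecond settings translate into sample counts, assuming whisper.cpp's fixed 16 kHz mono PCM rate; the constant here simply mirrors WHISPER_SAMPLE_RATE from the code.

    // Standalone sketch: millisecond settings to sample counts at 16 kHz.
    const SAMPLE_RATE: u32 = 16_000; // assumed, mirrors WHISPER_SAMPLE_RATE

    fn samples(ms: u32) -> usize {
        (ms * SAMPLE_RATE / 1000) as usize
    }

    fn main() {
        println!("step_ms   1000  -> {} samples per inference step", samples(1_000));  // 16000
        println!("length_ms 10000 -> {} samples fed to the model",   samples(10_000)); // 160000
        println!("keep_ms   200   -> {} samples kept for context",   samples(200));    // 3200
    }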
src/config.rs
CHANGED
@@ -1,4 +1,5 @@
 use std::ffi::c_int;
+use std::fs;
 use std::net::IpAddr;
 use lazy_static::lazy_static;
 use serde::{Deserialize};
@@ -24,8 +25,8 @@ lazy_static! {
 #[derive(Debug, Deserialize, Clone)]
 pub(crate) struct WhisperParams {
     pub(crate) n_threads: Option<usize>,
-
-
+    pub(crate) step_ms: u32,
+    pub(crate) length_ms: u32,
     pub(crate) keep_ms: u32,
     pub(crate) max_tokens: u32,
     pub(crate) audio_ctx: u32,
@@ -52,7 +53,7 @@ impl WhisperParams {
         param.set_print_realtime(false);
         param.set_print_timestamps(!self.no_timestamps);
         param.set_translate(self.translate);
-        param.set_single_segment(
+        param.set_single_segment(true);
         param.set_max_tokens(self.max_tokens as i32);
         let lang = self.language.as_ref().map(|s| s.as_str());
         param.set_language(lang);
@@ -62,7 +63,7 @@ impl WhisperParams {
         param.set_speed_up(self.speed_up);
         // param.set_tdrz_enable(self.tinydiarize);
         if self.no_fallback {
-            param.set_temperature_inc(
+            param.set_temperature_inc(-1.0);
         }
         if self.no_context {
             param.set_tokens(&NONE);
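Note: the new `use std::fs;` import and the extra WhisperParams fields suggest the whisper section of config.yaml is read straight into this struct at startup. A hedged loader sketch follows, covering only the keys visible in this diff; the real CONFIG is built in a lazy_static! block that is not shown here, and serde_yaml is an assumed dependency.

    // Hypothetical loader sketch; field names mirror the config.yaml keys above.
    use serde::Deserialize;

    #[derive(Debug, Deserialize, Clone)]
    struct WhisperSection {
        n_threads: Option<usize>,
        step_ms: u32,
        length_ms: u32,
        keep_ms: u32,
        max_tokens: u32,
        audio_ctx: u32,
        vad_thold: f32,
        freq_thold: f32,
    }

    #[derive(Debug, Deserialize)]
    struct ConfigFile {
        whisper: WhisperSection, // other sections (e.g. server) are ignored by serde
    }

    fn load(path: &str) -> Result<ConfigFile, Box<dyn std::error::Error>> {
        let text = std::fs::read_to_string(path)?;
        Ok(serde_yaml::from_str(&text)?)
    }

    fn main() {
        let config = load("config.yaml").expect("failed to load config.yaml");
        println!("step_ms = {}", config.whisper.step_ms);
    }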
src/group.rs
CHANGED
@@ -1,3 +1,4 @@
+use std::collections::VecDeque;
 use std::time::Duration;
 use tokio::{select};
 use tokio::time::sleep;
@@ -17,13 +18,13 @@ where I: 'static + Send {
                buffer: usize) -> Self {
         let (tx, outlet) = channel::<Vec<I>>(buffer);
         tokio::spawn(async move {
-            let mut window =
+            let mut window = VecDeque::with_capacity(group_size);
 
             loop {
                 let grouped_fut = async {
                     while let Some(c) = inlet.recv().await {
                         window.extend(c);
-                        if window.len()
+                        if window.len() >= group_size {
                             let will_send: Vec<I> = window.drain(0..group_size).collect();
                             return Some(will_send)
                         }
@@ -31,18 +32,22 @@ where I: 'static + Send {
                     return None
                 };
 
-                let grouped = select! {
+                let grouped: Vec<I> = select! {
                     _ = sleep(window_time) => {
                         window.drain(..).collect()
                     },
                     grouped_opt = grouped_fut => {
                         match grouped_opt {
                             None => break,
-                            Some(
+                            Some(group) => group
                         }
                     }
                 };
 
+                if grouped.is_empty() {
+                    continue
+                }
+
                 if let Err(e) = tx.send(grouped).await {
                     tracing::error!("{}", e);
                 }
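Note: GroupedWithin emits a batch either when group_size items have accumulated or when window_time elapses, whichever comes first. A simplified, self-contained sketch of that pattern follows; the function name and signature are illustrative, not the type's actual API, and it flushes everything on the size trigger where the real type drains exactly group_size.

    // Sketch of the "emit a batch when full or when the window elapses" pattern.
    use std::time::Duration;
    use tokio::sync::mpsc;
    use tokio::time::sleep;

    async fn group_within(
        mut inlet: mpsc::Receiver<Vec<u8>>,
        tx: mpsc::Sender<Vec<u8>>,
        group_size: usize,
        window: Duration,
    ) {
        let mut buf: Vec<u8> = Vec::new();
        loop {
            // Resolves to true once enough bytes arrived, false if the inlet closed.
            let filled = async {
                while let Some(chunk) = inlet.recv().await {
                    buf.extend(chunk);
                    if buf.len() >= group_size {
                        return true;
                    }
                }
                false
            };
            tokio::select! {
                _ = sleep(window) => {}                // window elapsed: flush what we have
                full = filled => { if !full { break } } // inlet closed: stop the task
            }
            if buf.is_empty() {
                continue; // nothing accumulated during this window
            }
            let batch: Vec<u8> = buf.drain(..).collect();
            if tx.send(batch).await.is_err() {
                break; // downstream receiver dropped
            }
        }
    }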
src/main.rs
CHANGED
@@ -35,14 +35,6 @@ struct Opt {
     /// The AWS Region.
     #[structopt(short, long)]
     region: Option<String>,
-    //
-    // /// The name of the audio file.
-    // #[structopt(short, long)]
-    // audio_file: String,
-    //
-    /// Whether to display additional information.
-    #[structopt(short, long)]
-    verbose: bool,
 }
 
 #[derive(Clone)]
@@ -57,22 +49,18 @@ async fn main() -> Result<(), std::io::Error> {
 
     let Opt {
         region,
-        verbose,
     } = Opt::parse();
 
     let region_provider = RegionProviderChain::first_try(region.map(Region::new))
         .or_default_provider()
         .or_else(Region::new("us-west-2"));
 
-
-
-
-        println!("Transcribe client version: {}", PKG_VERSION);
-        println!(
+    if tracing::enabled!(tracing::Level::DEBUG) {
+        tracing::debug!("Transcribe client version: {}", PKG_VERSION);
+        tracing::debug!(
             "Region: {}",
             region_provider.region().await.unwrap().as_ref()
         );
-        println!();
     }
 
     let shared_config = aws_config::from_env().region(region_provider).load().await;
@@ -118,36 +106,37 @@ async fn stream_speaker(ctx: Data<&Context>, query: Query<LessonSpeakerQuery>, w
         loop {
             select! {
                 w = whisper_transcribe_rx.recv() => {
-                    let Ok(
+                    let Ok(_txt) = w else {
+                        // TODO: handle msg
                         continue
                     };
-                    println!("Whisper: {:?}", txt)
                 }
                 msg = socket.next() => {
                     match msg.as_ref() {
                         Some(Ok(Message::Binary(bin))) => {
-                            let _ = whisper.send(bin.
-                            if origin_tx.send(bin.to_vec()).await
-
+                            let _ = whisper.send(bin.to_vec()).await; // whisper test
+                            if let Err(e) = origin_tx.send(bin.to_vec()).await {
+                                tracing::warn!("failed to send voice: {}", e);
                                 break;
                             }
                         },
                         Some(Ok(_)) => {
-
+                            tracing::warn!("Other: {:?}", msg);
                        },
                         Some(Err(e)) => {
-
+                            tracing::warn!("Error: {:?}", e);
                         },
                         None => {
-                            let
-
+                            if let Err(e) = socket.close().await {
+                                tracing::debug!("Message: {:?}, {}", msg, e);
+                            }
                             break;
                         }
                     }
                 },
                 output = transcribe_rx.recv() => {
                     if let Ok(transcript) = output {
-
+                        tracing::trace!("Transcribed: {}", transcript);
                         let evt = LiveLessonTextEvent::Transcription { text: transcript.clone() };
                         let json = serde_json::to_string(&evt).expect("failed to serialize");
                         let _ = socket.send(Message::Text(json)).await.expect("failed to send");
@@ -177,10 +166,16 @@ enum LiveLessonTextEvent {
 #[handler]
 async fn stream_listener(ctx: Data<&Context>, query: Query<LessonListenerQuery>, ws: WebSocket) -> impl IntoResponse {
     let lesson_opt = ctx.lessons_manager.get_lesson(query.id).await;
-
-    let voice_id = query.voice.parse().expect("Not supported voice");
+    tracing::debug!("listener param = {:?}", query);
 
     ws.on_upgrade(|mut socket| async move {
+        let voice_id = match query.voice.parse() {
+            Ok(id) => id,
+            Err(e) => {
+                let _ = socket.send(Message::Text(format!("invalid voice id: {}", e))).await;
+                return
+            }
+        };
         let Some(lesson) = lesson_opt else {
             let _ = socket.send(Message::Text("lesson not found".to_string())).await;
             return
@@ -197,17 +192,29 @@ async fn stream_listener(ctx: Data<&Context>, query: Query<LessonListenerQuery>,
             transcript = transcript_rx.recv() => {
                 if let Ok(transcript) = transcript {
                     let evt = LiveLessonTextEvent::Transcription { text: transcript };
-
-
-
+                    match serde_json::to_string(&evt) {
+                        Ok(json) => {
+                            tracing::debug!("Transcribed: {}", json);
+                            let _ = socket.send(Message::Text(json)).await;
+                        },
+                        Err(e) => {
+                            tracing::error!("failed to serialize: {}", e);
+                        }
+                    }
                 }
             },
             translated = translate_rx.recv() => {
                 if let Ok(translated) = translated {
                     let evt = LiveLessonTextEvent::Translation { text: translated };
-
-
-
+                    match serde_json::to_string(&evt) {
+                        Ok(json) => {
+                            tracing::debug!("Translated: {}", json);
+                            let _ = socket.send(Message::Text(json)).await;
+                        },
+                        Err(e) => {
+                            tracing::error!("failed to serialize: {}", e);
+                        }
+                    }
                 }
             },
             voice = voice_rx.recv() => {
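Note: both handlers serialize a LiveLessonTextEvent before writing it to the socket. The following hypothetical sketch shows the enum and the JSON a client would receive, assuming serde's default externally tagged representation; the enum's actual serde attributes are not visible in this diff, so the exact wire format may differ.

    // Hypothetical sketch of the event enum and its JSON shape.
    use serde::Serialize;

    #[derive(Serialize)]
    enum LiveLessonTextEvent {
        Transcription { text: String },
        Translation { text: String },
    }

    fn main() {
        let evt = LiveLessonTextEvent::Transcription { text: "hello".to_string() };
        // With serde's default (externally tagged) representation this prints:
        // {"Transcription":{"text":"hello"}}
        println!("{}", serde_json::to_string(&evt).unwrap());
    }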
src/whisper.rs
CHANGED
@@ -1,10 +1,11 @@
+use std::collections::VecDeque;
 use std::ffi::c_int;
 use std::fmt::{Debug, Display, Formatter};
 use std::thread::sleep;
 use std::time::Duration;
 use lazy_static::lazy_static;
 use tokio::sync::{broadcast, mpsc, oneshot};
-use whisper_rs::{convert_integer_to_float_audio, WhisperContext,
+use whisper_rs::{convert_integer_to_float_audio, WhisperState, WhisperContext, FullParams};
 use whisper_rs_sys::WHISPER_SAMPLE_RATE;
 use crate::config::{WhisperParams, CONFIG};
 use crate::group::GroupedWithin;
@@ -78,22 +79,28 @@ pub struct WhisperHandler {
 
 impl WhisperHandler {
     pub(crate) fn new(config: WhisperParams) -> Result<Self, Error> {
+        let n_samples_step: usize = (config.step_ms * WHISPER_SAMPLE_RATE / 1000) as usize;
+        let n_samples_len: usize = (config.length_ms * WHISPER_SAMPLE_RATE / 1000) as usize;
         let n_samples_keep: usize = (config.keep_ms * WHISPER_SAMPLE_RATE / 1000) as usize;
+        let n_new_line: usize = 1.max(config.length_ms / config.step_ms - 1) as usize;
+        let mut n_iter: usize = 0;
+
         let (stop_handle, mut stop_signal) = oneshot::channel();
         let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<u8>>(128);
         let mut grouped = GroupedWithin::new(
-
-            Duration::
+            n_samples_step * 2,
+            Duration::from_millis(config.step_ms as u64),
             pcm_rx,
-
+            u16::MAX as usize
         );
         let (transcription_tx, _) = broadcast::channel::<Vec<Segment>>(128);
         let shared_transcription_tx = transcription_tx.clone();
         let mut state = WHISPER_CONTEXT.create_state()
             .map_err(|e| Error::whisper_error("failed to create WhisperState", e))?;
         tokio::task::spawn_blocking(move || {
-            let mut
-            let mut pcm_f32:
+            let mut prompt_tokens: Vec<c_int> = Default::default();
+            let mut pcm_f32: VecDeque<f32> = VecDeque::from(vec![0f32; 30 * WHISPER_SAMPLE_RATE as usize]);
+            let mut last_offset: usize = 0;
             while let Err(oneshot::error::TryRecvError::Empty) = stop_signal.try_recv() {
                 let new_pcm_f32 = match grouped.next() {
                     Err(mpsc::error::TryRecvError::Disconnected) => break,
@@ -106,22 +113,39 @@ impl WhisperHandler {
                     }
                 };
 
+                let params = config.to_full_params(prompt_tokens.as_slice());
                 pcm_f32.extend(new_pcm_f32);
-
-
-
+                if pcm_f32.len() > n_samples_len + n_samples_keep {
+                    let _ = pcm_f32.drain(0..(pcm_f32.len() - n_samples_keep - n_samples_len)).len();
+                }
+                pcm_f32.make_contiguous();
+                let (data, _) = pcm_f32.as_slices();
+                let segments = match inference(&mut state, params, data, 0) {
+                    Ok((offset, result)) => {
+                        last_offset = offset;
+                        if result.is_empty() {
                             continue
                         }
-
-                        tracing::error!("failed to send transcription: {}", e);
-                        break
-                    }
+                        result
                     }
                     Err(err) => {
-                        tracing::
+                        tracing::warn!("failed to inference: {}", err);
                        continue
-                        // break
                     }
+                };
+
+                n_iter = n_iter + 1;
+
+                if n_iter % n_new_line == 0 {
+                    if let Err(e) = new_line(&mut pcm_f32, n_samples_keep, &config, &mut state, segments.len(), &mut prompt_tokens) {
+                        tracing::warn!("failed to new_line: {}", e);
+                    }
+                    tracing::debug!("LINE: {}", segments.first().unwrap().text);
+
+                    if let Err(e) = shared_transcription_tx.send(segments) {
+                        tracing::error!("failed to send transcription: {}", e);
+                        break
+                    };
                 }
             }
         });
@@ -143,17 +167,13 @@
 
 fn inference(
     state: &mut WhisperState,
-
-
-
-
-
-    let
-
-    let st = std::time::Instant::now();
-    let _ = state.full(params, pcm_f32.as_slice())
+    params: FullParams,
+    pcm_f32: &[f32],
+    mut offset: usize,
+) -> Result<(usize, Vec<Segment>), Error> {
+
+    let _ = state.full(params, pcm_f32)
         .map_err(|e| Error::whisper_error("failed to initialize WhisperState", e))?;
-    let et = std::time::Instant::now();
 
     let num_segments = state
         .full_n_segments()
@@ -163,28 +183,36 @@ fn inference(
         let segment = state
            .full_get_segment_text(i)
            .map_err(|e| Error::whisper_error("failed to get segment", e))?;
-        let
+        let timestamp_offset: i64 = (offset * 1000 / WHISPER_SAMPLE_RATE as usize) as i64;
+        let start_timestamp: i64 = timestamp_offset + 10 * state
            .full_get_segment_t0(i)
            .map_err(|e| Error::whisper_error("failed to get start timestamp", e))?;
-        let end_timestamp = state
+        let end_timestamp: i64 = timestamp_offset + 10 * state
            .full_get_segment_t1(i)
            .map_err(|e| Error::whisper_error("failed to get end timestamp", e))?;
-        tracing::
+        // tracing::trace!("{}", segment);
        segments.push(Segment { start_timestamp, end_timestamp, text: segment });
    }
 
+    Ok((offset, segments))
+}
+
+fn new_line(pcm_f32: &mut VecDeque<f32>,
+            n_samples_keep: usize,
+            config: &WhisperParams,
+            state: &mut WhisperState,
+            num_segments: usize,
+            prompt_tokens: &mut Vec<c_int>) -> Result<(), Error> {
+
+    // keep the last n_samples_keep samples from pcm_f32
+    if pcm_f32.len() > n_samples_keep {
+        let _ = pcm_f32.drain(0..(pcm_f32.len() - n_samples_keep)).len();
+    }
+
    if !config.no_context {
        prompt_tokens.clear();
 
-
-        if pcm_f32.len() > n_samples_keep {
-            let _ = pcm_f32.drain(0..(pcm_f32.len() - n_samples_keep)).collect::<Vec<_>>();
-        }
-
-        let n_segments = state
-            .full_n_segments()
-            .map_err(|e| Error::whisper_error("failed to get number of segments", e))?;
-        for i in 0..n_segments {
+        for i in 0..num_segments as c_int {
            let token_count = state
                .full_n_tokens(i)
                .map_err(|e| Error::whisper_error("failed to get number of tokens", e))?;
@@ -196,9 +224,7 @@ fn inference(
            }
        }
    }
-
-    tracing::trace!("took {}ms", (et - st).as_millis());
-    Ok(segments)
+    Ok(())
 }
 
 impl Drop for WhisperHandler {
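Note: the inference loop keeps a rolling VecDeque of samples, trims it to at most keep_ms + length_ms worth of audio, and makes it contiguous before handing a slice to state.full(). Below is a standalone sketch of just that trim step, using the values from config.yaml at 16 kHz; the function name is illustrative and not part of the diff.

    // Sketch of the rolling-buffer trim performed before each inference call.
    use std::collections::VecDeque;

    fn trim_window(pcm_f32: &mut VecDeque<f32>, n_samples_keep: usize, n_samples_len: usize) -> &[f32] {
        // Drop the oldest samples so at most keep + length samples remain.
        if pcm_f32.len() > n_samples_len + n_samples_keep {
            let excess = pcm_f32.len() - n_samples_keep - n_samples_len;
            pcm_f32.drain(0..excess);
        }
        // A VecDeque may wrap around its ring buffer; make it one contiguous
        // slice so it can be passed to the model as &[f32].
        pcm_f32.make_contiguous();
        pcm_f32.as_slices().0
    }

    fn main() {
        let mut buf: VecDeque<f32> = (0..200_000).map(|i| i as f32).collect();
        // keep_ms = 200 and length_ms = 10000 at 16 kHz give 3 200 and 160 000 samples.
        let window = trim_window(&mut buf, 3_200, 160_000);
        assert_eq!(window.len(), 163_200);
    }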