Spaces:
Sleeping
Sleeping
Context
Browse files- Cargo.lock +6 -6
- Cargo.toml +17 -17
- config.yaml +19 -15
- src/config.rs +18 -18
- src/whisper.rs +63 -54
Cargo.lock
CHANGED
@@ -200,9 +200,9 @@ dependencies = [
|
|
200 |
|
201 |
[[package]]
|
202 |
name = "aws-sdk-polly"
|
203 |
-
version = "0.
|
204 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
205 |
-
checksum = "
|
206 |
dependencies = [
|
207 |
"aws-credential-types",
|
208 |
"aws-http",
|
@@ -273,9 +273,9 @@ dependencies = [
|
|
273 |
|
274 |
[[package]]
|
275 |
name = "aws-sdk-transcribestreaming"
|
276 |
-
version = "0.
|
277 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
278 |
-
checksum = "
|
279 |
dependencies = [
|
280 |
"aws-credential-types",
|
281 |
"aws-http",
|
@@ -299,9 +299,9 @@ dependencies = [
|
|
299 |
|
300 |
[[package]]
|
301 |
name = "aws-sdk-translate"
|
302 |
-
version = "0.
|
303 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
304 |
-
checksum = "
|
305 |
dependencies = [
|
306 |
"aws-credential-types",
|
307 |
"aws-http",
|
|
|
200 |
|
201 |
[[package]]
|
202 |
name = "aws-sdk-polly"
|
203 |
+
version = "0.34.0"
|
204 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
205 |
+
checksum = "6950a413936042ea34341331eb2fe7635f33d760ffff3dd11afa4ad35d5151a5"
|
206 |
dependencies = [
|
207 |
"aws-credential-types",
|
208 |
"aws-http",
|
|
|
273 |
|
274 |
[[package]]
|
275 |
name = "aws-sdk-transcribestreaming"
|
276 |
+
version = "0.34.0"
|
277 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
278 |
+
checksum = "bf1b3a068ad56440585dccc3210a537bbe9653b07897698339e0d04c234463d2"
|
279 |
dependencies = [
|
280 |
"aws-credential-types",
|
281 |
"aws-http",
|
|
|
299 |
|
300 |
[[package]]
|
301 |
name = "aws-sdk-translate"
|
302 |
+
version = "0.34.0"
|
303 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
304 |
+
checksum = "23f27dd4ee2f7aeb6736cb24bad9f0ef0bb08194bd598d8fa8e96261c6689409"
|
305 |
dependencies = [
|
306 |
"aws-credential-types",
|
307 |
"aws-http",
|
Cargo.toml
CHANGED
@@ -5,25 +5,25 @@ edition = "2021"
|
|
5 |
|
6 |
[dependencies]
|
7 |
aws-config= { version = "0.56.1" }
|
8 |
-
aws-sdk-transcribestreaming= { version = "0.
|
9 |
-
aws-sdk-translate = "0.
|
10 |
-
aws-sdk-polly = "0.
|
11 |
-
clap = { version = "4.4
|
12 |
-
tokio = { version = "1.33
|
13 |
-
tokio-stream = "0.1
|
14 |
-
async-stream = "0.3
|
15 |
-
futures-util = "0.3
|
16 |
-
serde = { version = "1.0
|
17 |
-
serde_json = { version = "1.0
|
18 |
-
serde_yaml = "0.9
|
19 |
-
whisper-rs = { version = "0.8
|
20 |
-
whisper-rs-sys = "0.6
|
21 |
-
tracing = "0.1
|
22 |
-
tracing-subscriber = "0.3
|
23 |
-
lazy_static = "1.4
|
24 |
|
25 |
[dependencies.poem]
|
26 |
-
version = "1.3
|
27 |
features = ["websocket", "static-files"]
|
28 |
|
29 |
[target.aarch64-apple-darwin.dependencies.whisper-rs]
|
|
|
5 |
|
6 |
[dependencies]
|
7 |
aws-config= { version = "0.56.1" }
|
8 |
+
aws-sdk-transcribestreaming= { version = "0.34" }
|
9 |
+
aws-sdk-translate = "0.34"
|
10 |
+
aws-sdk-polly = "0.34"
|
11 |
+
clap = { version = "4.4" , features = ["derive"]}
|
12 |
+
tokio = { version = "1.33" , features = ["full"] }
|
13 |
+
tokio-stream = "0.1"
|
14 |
+
async-stream = "0.3"
|
15 |
+
futures-util = "0.3"
|
16 |
+
serde = { version = "1.0", features = ["derive"] }
|
17 |
+
serde_json = { version = "1.0" }
|
18 |
+
serde_yaml = "0.9"
|
19 |
+
whisper-rs = { version = "0.8" }
|
20 |
+
whisper-rs-sys = "0.6"
|
21 |
+
tracing = "0.1"
|
22 |
+
tracing-subscriber = "0.3"
|
23 |
+
lazy_static = "1.4"
|
24 |
|
25 |
[dependencies.poem]
|
26 |
+
version = "1.3"
|
27 |
features = ["websocket", "static-files"]
|
28 |
|
29 |
[target.aarch64-apple-darwin.dependencies.whisper-rs]
|
config.yaml
CHANGED
@@ -2,20 +2,24 @@ server:
|
|
2 |
port: 8080
|
3 |
host: ::1
|
4 |
whisper:
|
5 |
-
|
6 |
-
length_ms: 5000
|
7 |
keep_ms: 200
|
8 |
step_ms: 5000
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
2 |
port: 8080
|
3 |
host: ::1
|
4 |
whisper:
|
5 |
+
length_ms: 10000
|
|
|
6 |
keep_ms: 200
|
7 |
step_ms: 5000
|
8 |
+
model: "models/ggml-base.bin"
|
9 |
+
max_prompt_tokens: 128
|
10 |
+
params:
|
11 |
+
#n_threads: 4
|
12 |
+
max_tokens: 0
|
13 |
+
audio_ctx: 0
|
14 |
+
speed_up: false
|
15 |
+
single_segment: false
|
16 |
+
translate: false
|
17 |
+
no_fallback: false
|
18 |
+
temperature_inc: -1.0
|
19 |
+
print_special: false
|
20 |
+
print_progress: false
|
21 |
+
print_realtime: false
|
22 |
+
no_context: false
|
23 |
+
no_timestamps: false
|
24 |
+
tinydiarize: false
|
25 |
+
language: "en"
|
src/config.rs
CHANGED
@@ -22,24 +22,31 @@ lazy_static! {
|
|
22 |
}
|
23 |
|
24 |
#[derive(Debug, Deserialize, Clone)]
|
25 |
-
pub(crate) struct
|
26 |
-
pub(crate)
|
27 |
pub(crate) step_ms: u32,
|
28 |
pub(crate) length_ms: u32,
|
29 |
pub(crate) keep_ms: u32,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
pub(crate) max_tokens: u32,
|
31 |
pub(crate) audio_ctx: u32,
|
32 |
-
// pub(crate) vad_thold: f32,
|
33 |
-
// pub(crate) freq_thold: f32,
|
34 |
pub(crate) speed_up: bool,
|
35 |
pub(crate) translate: bool,
|
36 |
pub(crate) no_fallback: bool,
|
37 |
pub(crate) print_special: bool,
|
38 |
-
pub(crate)
|
|
|
39 |
pub(crate) no_timestamps: bool,
|
|
|
|
|
40 |
// pub(crate) tinydiarize: bool,
|
41 |
pub(crate) language: Option<String>,
|
42 |
-
pub(crate) model: String,
|
43 |
}
|
44 |
|
45 |
const NONE: [c_int;0] = [];
|
@@ -47,12 +54,12 @@ const NONE: [c_int;0] = [];
|
|
47 |
impl WhisperParams {
|
48 |
pub(crate) fn to_full_params<'a, 'b>(&'a self, tokens: &'b [c_int]) -> FullParams<'a, 'b> {
|
49 |
let mut param = FullParams::new(Default::default());
|
50 |
-
param.set_print_progress(
|
51 |
param.set_print_special(self.print_special);
|
52 |
-
param.set_print_realtime(
|
53 |
param.set_print_timestamps(!self.no_timestamps);
|
54 |
param.set_translate(self.translate);
|
55 |
-
param.set_single_segment(
|
56 |
param.set_max_tokens(self.max_tokens as i32);
|
57 |
let lang = self.language.as_ref().map(|s| s.as_str());
|
58 |
param.set_language(lang);
|
@@ -61,14 +68,7 @@ impl WhisperParams {
|
|
61 |
param.set_audio_ctx(self.audio_ctx as i32);
|
62 |
param.set_speed_up(self.speed_up);
|
63 |
// param.set_tdrz_enable(self.tinydiarize);
|
64 |
-
|
65 |
-
param.set_temperature_inc(-1.0);
|
66 |
-
}
|
67 |
-
if self.no_context {
|
68 |
-
param.set_tokens(&NONE);
|
69 |
-
} else {
|
70 |
-
param.set_tokens(&tokens);
|
71 |
-
}
|
72 |
|
73 |
param
|
74 |
}
|
@@ -82,7 +82,7 @@ pub(crate) struct Server {
|
|
82 |
|
83 |
#[derive(Debug, Deserialize)]
|
84 |
pub struct Config {
|
85 |
-
pub(crate) whisper:
|
86 |
pub(crate) server: Server,
|
87 |
}
|
88 |
|
|
|
22 |
}
|
23 |
|
24 |
#[derive(Debug, Deserialize, Clone)]
|
25 |
+
pub(crate) struct WhisperConfig {
|
26 |
+
pub(crate) params: WhisperParams,
|
27 |
pub(crate) step_ms: u32,
|
28 |
pub(crate) length_ms: u32,
|
29 |
pub(crate) keep_ms: u32,
|
30 |
+
pub(crate) model: String,
|
31 |
+
pub(crate) max_prompt_tokens: usize,
|
32 |
+
}
|
33 |
+
|
34 |
+
#[derive(Debug, Deserialize, Clone)]
|
35 |
+
pub(crate) struct WhisperParams {
|
36 |
+
pub(crate) n_threads: Option<usize>,
|
37 |
pub(crate) max_tokens: u32,
|
38 |
pub(crate) audio_ctx: u32,
|
|
|
|
|
39 |
pub(crate) speed_up: bool,
|
40 |
pub(crate) translate: bool,
|
41 |
pub(crate) no_fallback: bool,
|
42 |
pub(crate) print_special: bool,
|
43 |
+
pub(crate) print_realtime: bool,
|
44 |
+
pub(crate) print_progress: bool,
|
45 |
pub(crate) no_timestamps: bool,
|
46 |
+
pub(crate) temperature_inc: f32,
|
47 |
+
pub(crate) single_segment: bool,
|
48 |
// pub(crate) tinydiarize: bool,
|
49 |
pub(crate) language: Option<String>,
|
|
|
50 |
}
|
51 |
|
52 |
const NONE: [c_int;0] = [];
|
|
|
54 |
impl WhisperParams {
|
55 |
pub(crate) fn to_full_params<'a, 'b>(&'a self, tokens: &'b [c_int]) -> FullParams<'a, 'b> {
|
56 |
let mut param = FullParams::new(Default::default());
|
57 |
+
param.set_print_progress(self.print_progress);
|
58 |
param.set_print_special(self.print_special);
|
59 |
+
param.set_print_realtime(self.print_realtime);
|
60 |
param.set_print_timestamps(!self.no_timestamps);
|
61 |
param.set_translate(self.translate);
|
62 |
+
param.set_single_segment(false);
|
63 |
param.set_max_tokens(self.max_tokens as i32);
|
64 |
let lang = self.language.as_ref().map(|s| s.as_str());
|
65 |
param.set_language(lang);
|
|
|
68 |
param.set_audio_ctx(self.audio_ctx as i32);
|
69 |
param.set_speed_up(self.speed_up);
|
70 |
// param.set_tdrz_enable(self.tinydiarize);
|
71 |
+
param.set_temperature_inc(self.temperature_inc);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
param
|
74 |
}
|
|
|
82 |
|
83 |
#[derive(Debug, Deserialize)]
|
84 |
pub struct Config {
|
85 |
+
pub(crate) whisper: WhisperConfig,
|
86 |
pub(crate) server: Server,
|
87 |
}
|
88 |
|
src/whisper.rs
CHANGED
@@ -7,7 +7,7 @@ use lazy_static::lazy_static;
|
|
7 |
use tokio::sync::{broadcast, mpsc, oneshot};
|
8 |
use whisper_rs::{convert_integer_to_float_audio, WhisperState, WhisperContext};
|
9 |
use whisper_rs_sys::WHISPER_SAMPLE_RATE;
|
10 |
-
use crate::config::{WhisperParams, CONFIG};
|
11 |
use crate::group::GroupedWithin;
|
12 |
|
13 |
lazy_static! {
|
@@ -69,6 +69,7 @@ pub struct Segment {
|
|
69 |
pub start_timestamp: i64,
|
70 |
pub end_timestamp: i64,
|
71 |
pub text: String,
|
|
|
72 |
}
|
73 |
|
74 |
pub struct WhisperHandler {
|
@@ -78,7 +79,7 @@ pub struct WhisperHandler {
|
|
78 |
}
|
79 |
|
80 |
impl WhisperHandler {
|
81 |
-
pub(crate) fn new(config:
|
82 |
let (stop_handle, mut stop_signal) = oneshot::channel();
|
83 |
let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<u8>>(128);
|
84 |
let (transcription_tx, _) = broadcast::channel::<Vec<Segment>>(128);
|
@@ -121,10 +122,11 @@ impl WhisperHandler {
|
|
121 |
|
122 |
if tracing::enabled!(tracing::Level::TRACE) {
|
123 |
for segment in segments.iter() {
|
124 |
-
tracing::trace!("[{}] SEGMENT: {}",
|
|
|
|
|
|
|
125 |
}
|
126 |
-
} else if tracing::enabled!(tracing::Level::DEBUG) {
|
127 |
-
tracing::debug!("[{}] SEGMENT: {}", detector.n_iter, segments[0].text);
|
128 |
}
|
129 |
|
130 |
if let Err(e) = shared_transcription_tx.send(segments) {
|
@@ -151,107 +153,114 @@ impl WhisperHandler {
|
|
151 |
|
152 |
struct Detector {
|
153 |
state: WhisperState<'static>,
|
154 |
-
config: &'static
|
155 |
n_samples_keep: usize,
|
156 |
n_samples_step: usize,
|
157 |
n_samples_len: usize,
|
158 |
-
n_new_line: usize,
|
159 |
-
n_iter: usize,
|
160 |
prompt_tokens: Vec<c_int>,
|
161 |
pcm_f32: VecDeque<f32>,
|
162 |
offset: usize,
|
|
|
163 |
}
|
164 |
|
165 |
impl Detector {
|
166 |
fn new(state: WhisperState<'static>,
|
167 |
-
config: &'static
|
168 |
Detector {
|
169 |
state,
|
170 |
config,
|
171 |
n_samples_keep: (config.keep_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
|
172 |
n_samples_step: (config.step_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
|
173 |
n_samples_len: (config.length_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
|
174 |
-
n_new_line: 1.max(config.length_ms / config.step_ms - 1) as usize,
|
175 |
-
n_iter: 0,
|
176 |
prompt_tokens: Default::default(),
|
177 |
pcm_f32: VecDeque::from(vec![0f32; 30 * WHISPER_SAMPLE_RATE as usize]),
|
178 |
offset: 0,
|
|
|
179 |
}
|
180 |
}
|
181 |
|
182 |
fn feed(&mut self, new_pcm_f32: Vec<f32>) {
|
183 |
self.pcm_f32.extend(new_pcm_f32);
|
184 |
-
if self.pcm_f32.len()
|
185 |
-
|
186 |
}
|
|
|
|
|
187 |
}
|
188 |
|
189 |
fn inference(&mut self) -> Result<Vec<Segment>, Error> {
|
190 |
-
let params = self.config.to_full_params(self.prompt_tokens.as_slice());
|
191 |
let start = std::time::Instant::now();
|
192 |
let _ = self.state.full(params, self.pcm_f32.make_contiguous())
|
193 |
.map_err(|e| Error::whisper_error("failed to initialize WhisperState", e))?;
|
194 |
let end = std::time::Instant::now();
|
195 |
if end - start > Duration::from_millis(self.config.step_ms as u64) {
|
196 |
-
tracing::warn!("full() took {} ms too slow", (end - start).as_millis());
|
197 |
}
|
198 |
|
|
|
|
|
199 |
let num_segments = self.state
|
200 |
.full_n_segments()
|
201 |
.map_err(|e| Error::whisper_error("failed to get number of segments", e))?;
|
202 |
let mut segments: Vec<Segment> = Vec::with_capacity(num_segments as usize);
|
203 |
for i in 0..num_segments {
|
204 |
-
let segment = self.state
|
205 |
-
.full_get_segment_text(i)
|
206 |
-
.map_err(|e| Error::whisper_error("failed to get segment", e))?;
|
207 |
-
let timestamp_offset: i64 = (self.offset * 1000 / WHISPER_SAMPLE_RATE as usize) as i64;
|
208 |
-
let start_timestamp: i64 = timestamp_offset + 10 * self.state
|
209 |
-
.full_get_segment_t0(i)
|
210 |
-
.map_err(|e| Error::whisper_error("failed to get start timestamp", e))?;
|
211 |
let end_timestamp: i64 = timestamp_offset + 10 * self.state
|
212 |
.full_get_segment_t1(i)
|
213 |
.map_err(|e| Error::whisper_error("failed to get end timestamp", e))?;
|
214 |
-
|
215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
}
|
217 |
|
|
|
|
|
|
|
218 |
|
219 |
-
|
|
|
|
|
220 |
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
Ok(vec![])
|
226 |
-
}
|
227 |
}
|
228 |
|
229 |
-
fn
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
|
231 |
-
|
232 |
-
|
233 |
-
let _ = self.pcm_f32.drain(0..(self.pcm_f32.len() - self.n_samples_keep)).len();
|
234 |
}
|
235 |
-
|
236 |
-
|
237 |
-
self.prompt_tokens.clear();
|
238 |
-
|
239 |
-
let num_segments = self.state
|
240 |
-
.full_n_segments()
|
241 |
-
.map_err(|e| Error::whisper_error("failed to get number of segments", e))?;
|
242 |
-
for i in 0..num_segments {
|
243 |
-
let token_count = self.state
|
244 |
-
.full_n_tokens(i)
|
245 |
-
.map_err(|e| Error::whisper_error("failed to get number of tokens", e))?;
|
246 |
-
for j in 0..token_count {
|
247 |
-
let token = self.state
|
248 |
-
.full_get_token_id(i, j)
|
249 |
-
.map_err(|e| Error::whisper_error("failed to get token", e))?;
|
250 |
-
self.prompt_tokens.push(token);
|
251 |
-
}
|
252 |
-
}
|
253 |
}
|
254 |
-
Ok(())
|
255 |
}
|
256 |
}
|
257 |
|
|
|
7 |
use tokio::sync::{broadcast, mpsc, oneshot};
|
8 |
use whisper_rs::{convert_integer_to_float_audio, WhisperState, WhisperContext};
|
9 |
use whisper_rs_sys::WHISPER_SAMPLE_RATE;
|
10 |
+
use crate::config::{WhisperParams, CONFIG, WhisperConfig};
|
11 |
use crate::group::GroupedWithin;
|
12 |
|
13 |
lazy_static! {
|
|
|
69 |
pub start_timestamp: i64,
|
70 |
pub end_timestamp: i64,
|
71 |
pub text: String,
|
72 |
+
tokens: Vec<c_int>,
|
73 |
}
|
74 |
|
75 |
pub struct WhisperHandler {
|
|
|
79 |
}
|
80 |
|
81 |
impl WhisperHandler {
|
82 |
+
pub(crate) fn new(config: WhisperConfig) -> Result<Self, Error> {
|
83 |
let (stop_handle, mut stop_signal) = oneshot::channel();
|
84 |
let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<u8>>(128);
|
85 |
let (transcription_tx, _) = broadcast::channel::<Vec<Segment>>(128);
|
|
|
122 |
|
123 |
if tracing::enabled!(tracing::Level::TRACE) {
|
124 |
for segment in segments.iter() {
|
125 |
+
tracing::trace!("[{}-{}]s SEGMENT: {}",
|
126 |
+
segment.start_timestamp as f32 / 1000.0,
|
127 |
+
segment.end_timestamp as f32 / 1000.0,
|
128 |
+
segment.text);
|
129 |
}
|
|
|
|
|
130 |
}
|
131 |
|
132 |
if let Err(e) = shared_transcription_tx.send(segments) {
|
|
|
153 |
|
154 |
struct Detector {
|
155 |
state: WhisperState<'static>,
|
156 |
+
config: &'static WhisperConfig,
|
157 |
n_samples_keep: usize,
|
158 |
n_samples_step: usize,
|
159 |
n_samples_len: usize,
|
|
|
|
|
160 |
prompt_tokens: Vec<c_int>,
|
161 |
pcm_f32: VecDeque<f32>,
|
162 |
offset: usize,
|
163 |
+
stable_offset: usize,
|
164 |
}
|
165 |
|
166 |
impl Detector {
|
167 |
fn new(state: WhisperState<'static>,
|
168 |
+
config: &'static WhisperConfig) -> Self {
|
169 |
Detector {
|
170 |
state,
|
171 |
config,
|
172 |
n_samples_keep: (config.keep_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
|
173 |
n_samples_step: (config.step_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
|
174 |
n_samples_len: (config.length_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
|
|
|
|
|
175 |
prompt_tokens: Default::default(),
|
176 |
pcm_f32: VecDeque::from(vec![0f32; 30 * WHISPER_SAMPLE_RATE as usize]),
|
177 |
offset: 0,
|
178 |
+
stable_offset: 0
|
179 |
}
|
180 |
}
|
181 |
|
182 |
fn feed(&mut self, new_pcm_f32: Vec<f32>) {
|
183 |
self.pcm_f32.extend(new_pcm_f32);
|
184 |
+
if self.pcm_f32.len() < self.n_samples_len {
|
185 |
+
return
|
186 |
}
|
187 |
+
let len_to_drain = self.pcm_f32.drain(0..(self.pcm_f32.len() - self.n_samples_len)).len();
|
188 |
+
self.offset += len_to_drain;
|
189 |
}
|
190 |
|
191 |
fn inference(&mut self) -> Result<Vec<Segment>, Error> {
|
192 |
+
let params = self.config.params.to_full_params(self.prompt_tokens.as_slice());
|
193 |
let start = std::time::Instant::now();
|
194 |
let _ = self.state.full(params, self.pcm_f32.make_contiguous())
|
195 |
.map_err(|e| Error::whisper_error("failed to initialize WhisperState", e))?;
|
196 |
let end = std::time::Instant::now();
|
197 |
if end - start > Duration::from_millis(self.config.step_ms as u64) {
|
198 |
+
tracing::warn!("full([{}]) took {} ms too slow", self.pcm_f32.len(), (end - start).as_millis());
|
199 |
}
|
200 |
|
201 |
+
let timestamp_offset: i64 = (self.offset * 1000 / WHISPER_SAMPLE_RATE as usize) as i64;
|
202 |
+
let stable_offset: i64 = (self.stable_offset * 1000 / WHISPER_SAMPLE_RATE as usize) as i64;
|
203 |
let num_segments = self.state
|
204 |
.full_n_segments()
|
205 |
.map_err(|e| Error::whisper_error("failed to get number of segments", e))?;
|
206 |
let mut segments: Vec<Segment> = Vec::with_capacity(num_segments as usize);
|
207 |
for i in 0..num_segments {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
let end_timestamp: i64 = timestamp_offset + 10 * self.state
|
209 |
.full_get_segment_t1(i)
|
210 |
.map_err(|e| Error::whisper_error("failed to get end timestamp", e))?;
|
211 |
+
if end_timestamp <= stable_offset {
|
212 |
+
continue
|
213 |
+
}
|
214 |
+
|
215 |
+
let start_timestamp: i64 = timestamp_offset + 10 * self.state
|
216 |
+
.full_get_segment_t0(i)
|
217 |
+
.map_err(|e| Error::whisper_error("failed to get start timestamp", e))?;
|
218 |
+
let segment = self.state
|
219 |
+
.full_get_segment_text(i)
|
220 |
+
.map_err(|e| Error::whisper_error("failed to get segment", e))?;
|
221 |
+
let num_tokens = self.state
|
222 |
+
.full_n_tokens(i)
|
223 |
+
.map_err(|e| Error::whisper_error("failed to get segment tokens", e))?;
|
224 |
+
let mut segment_tokens = Vec::with_capacity(num_tokens as usize);
|
225 |
+
for j in 0..num_tokens {
|
226 |
+
segment_tokens.push(
|
227 |
+
self.state
|
228 |
+
.full_get_token_id(i, j)
|
229 |
+
.map_err(|e| Error::whisper_error("failed to get token", e))?
|
230 |
+
);
|
231 |
+
}
|
232 |
+
|
233 |
+
segments.push(Segment { start_timestamp, end_timestamp, text: segment.trim().to_string(), tokens: segment_tokens });
|
234 |
}
|
235 |
|
236 |
+
let Some((_last, init)) = segments.split_last() else {
|
237 |
+
return Ok(Vec::default())
|
238 |
+
};
|
239 |
|
240 |
+
let Some((last_2_seg, _)) = init.split_last() else {
|
241 |
+
return Ok(Vec::default())
|
242 |
+
};
|
243 |
|
244 |
+
let offset = (last_2_seg.end_timestamp - timestamp_offset) as usize / 1000 * WHISPER_SAMPLE_RATE as usize;
|
245 |
+
self.stable_offset = offset;
|
246 |
+
self.drop_stable_by_segments(init);
|
247 |
+
Ok(init.into())
|
|
|
|
|
248 |
}
|
249 |
|
250 |
+
fn drop_stable_by_segments(&mut self, stable_segments: &[Segment]) {
|
251 |
+
let Some(last) = stable_segments.last() else {
|
252 |
+
return
|
253 |
+
};
|
254 |
+
let drop_offset: usize = (last.end_timestamp as usize / 1000 * WHISPER_SAMPLE_RATE as usize - self.offset) as usize;
|
255 |
+
let len_to_drain = self.pcm_f32.drain(0..drop_offset).len();
|
256 |
+
self.offset += len_to_drain;
|
257 |
|
258 |
+
for segment in stable_segments.into_iter() {
|
259 |
+
self.prompt_tokens.extend(&segment.tokens);
|
|
|
260 |
}
|
261 |
+
if self.prompt_tokens.len() > self.config.max_prompt_tokens {
|
262 |
+
let _ = self.prompt_tokens.drain(0..(self.prompt_tokens.len() - self.config.max_prompt_tokens)).len();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
}
|
|
|
264 |
}
|
265 |
}
|
266 |
|