mingyang91 commited on
Commit
a4dee07
·
verified ·
1 Parent(s): ee7230e
Files changed (7) hide show
  1. Cargo.lock +5 -2
  2. Cargo.toml +4 -1
  3. config.yaml +3 -4
  4. src/config.rs +34 -16
  5. src/group.rs +59 -0
  6. src/main.rs +15 -22
  7. src/whisper.rs +199 -15
Cargo.lock CHANGED
@@ -1419,14 +1419,17 @@ dependencies = [
1419
  "aws-sdk-translate",
1420
  "clap",
1421
  "futures-util",
 
1422
  "poem",
1423
  "serde",
1424
  "serde_json",
1425
  "serde_yaml",
1426
  "tokio",
1427
  "tokio-stream",
 
1428
  "tracing-subscriber",
1429
  "whisper-rs",
 
1430
  ]
1431
 
1432
  [[package]]
@@ -2072,9 +2075,9 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52"
2072
 
2073
  [[package]]
2074
  name = "tracing"
2075
- version = "0.1.39"
2076
  source = "registry+https://github.com/rust-lang/crates.io-index"
2077
- checksum = "ee2ef2af84856a50c1d430afce2fdded0a4ec7eda868db86409b4543df0797f9"
2078
  dependencies = [
2079
  "log",
2080
  "pin-project-lite",
 
1419
  "aws-sdk-translate",
1420
  "clap",
1421
  "futures-util",
1422
+ "lazy_static",
1423
  "poem",
1424
  "serde",
1425
  "serde_json",
1426
  "serde_yaml",
1427
  "tokio",
1428
  "tokio-stream",
1429
+ "tracing",
1430
  "tracing-subscriber",
1431
  "whisper-rs",
1432
+ "whisper-rs-sys",
1433
  ]
1434
 
1435
  [[package]]
 
2075
 
2076
  [[package]]
2077
  name = "tracing"
2078
+ version = "0.1.40"
2079
  source = "registry+https://github.com/rust-lang/crates.io-index"
2080
+ checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
2081
  dependencies = [
2082
  "log",
2083
  "pin-project-lite",
Cargo.toml CHANGED
@@ -12,12 +12,15 @@ clap = { version = "4.4.6" , features = ["derive"]}
12
  tokio = { version = "1.33.0" , features = ["full"] }
13
  tokio-stream = "0.1.14"
14
  async-stream = "0.3.5"
15
- tracing-subscriber = "0.3.17"
16
  futures-util = "0.3.28"
17
  serde = { version = "1.0.189", features = ["derive"] }
18
  serde_json = { version = "1.0.107", features = [] }
19
  serde_yaml = "0.9.25"
20
  whisper-rs = { version = "0.8.0" , features = ["coreml"] }
 
 
 
 
21
 
22
  [dependencies.poem]
23
  version = "1.3.58"
 
12
  tokio = { version = "1.33.0" , features = ["full"] }
13
  tokio-stream = "0.1.14"
14
  async-stream = "0.3.5"
 
15
  futures-util = "0.3.28"
16
  serde = { version = "1.0.189", features = ["derive"] }
17
  serde_json = { version = "1.0.107", features = [] }
18
  serde_yaml = "0.9.25"
19
  whisper-rs = { version = "0.8.0" , features = ["coreml"] }
20
+ whisper-rs-sys = "0.6.1"
21
+ tracing = "0.1.40"
22
+ tracing-subscriber = "0.3.17"
23
+ lazy_static = "1.4.0"
24
 
25
  [dependencies.poem]
26
  version = "1.3.58"
config.yaml CHANGED
@@ -2,11 +2,10 @@ server:
2
  port: 8080
3
  host: ::1
4
  whisper:
5
- n_threads: 4
6
  step_ms: 500
7
  length_ms: 5000
8
- keep_ms: 200
9
- capture_id: -1
10
  max_tokens: 32
11
  audio_ctx: 0
12
  vad_thold: 0.6
@@ -15,7 +14,7 @@ whisper:
15
  translate: false
16
  no_fallback: false
17
  print_special: false
18
- no_context: true
19
  no_timestamps: false
20
  tinydiarize: false
21
  language: "en"
 
2
  port: 8080
3
  host: ::1
4
  whisper:
5
+ # n_threads: 4
6
  step_ms: 500
7
  length_ms: 5000
8
+ keep_ms: 5000
 
9
  max_tokens: 32
10
  audio_ctx: 0
11
  vad_thold: 0.6
 
14
  translate: false
15
  no_fallback: false
16
  print_special: false
17
+ no_context: false
18
  no_timestamps: false
19
  tinydiarize: false
20
  language: "en"
src/config.rs CHANGED
@@ -1,27 +1,43 @@
1
  use std::ffi::c_int;
2
- use std::fs;
3
  use std::net::IpAddr;
 
4
  use serde::{Deserialize};
5
  use whisper_rs::FullParams;
6
 
7
- #[derive(Debug, Deserialize)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  pub(crate) struct WhisperParams {
9
  pub(crate) n_threads: Option<usize>,
10
- pub(crate) step_ms: i32,
11
- pub(crate) length_ms: i32,
12
- pub(crate) keep_ms: i32,
13
- pub(crate) capture_id: i32,
14
- pub(crate) max_tokens: i32,
15
- pub(crate) audio_ctx: i32,
16
- pub(crate) vad_thold: f32,
17
- pub(crate) freq_thold: f32,
18
  pub(crate) speed_up: bool,
19
  pub(crate) translate: bool,
20
  pub(crate) no_fallback: bool,
21
  pub(crate) print_special: bool,
22
  pub(crate) no_context: bool,
23
  pub(crate) no_timestamps: bool,
24
- pub(crate) tinydiarize: bool,
25
  pub(crate) language: Option<String>,
26
  pub(crate) model: String,
27
  }
@@ -29,20 +45,20 @@ pub(crate) struct WhisperParams {
29
  const NONE: [c_int;0] = [];
30
 
31
  impl WhisperParams {
32
- pub(crate) fn to_full_params<'a, 'b>(&'a self) -> FullParams<'a, 'b> {
33
  let mut param = FullParams::new(Default::default());
34
  param.set_print_progress(false);
35
  param.set_print_special(self.print_special);
36
  param.set_print_realtime(false);
37
  param.set_print_timestamps(!self.no_timestamps);
38
  param.set_translate(self.translate);
39
- param.set_single_segment(true);
40
- param.set_max_tokens(self.max_tokens);
41
  let lang = self.language.as_ref().map(|s| s.as_str());
42
  param.set_language(lang);
43
  let num_cpus = std::thread::available_parallelism().map(|c| c.get()).unwrap_or(4);
44
  param.set_n_threads(self.n_threads.unwrap_or(num_cpus) as c_int);
45
- param.set_audio_ctx(self.audio_ctx);
46
  param.set_speed_up(self.speed_up);
47
  // param.set_tdrz_enable(self.tinydiarize);
48
  if self.no_fallback {
@@ -50,6 +66,8 @@ impl WhisperParams {
50
  }
51
  if self.no_context {
52
  param.set_tokens(&NONE);
 
 
53
  }
54
 
55
  param
@@ -63,7 +81,7 @@ pub(crate) struct Server {
63
  }
64
 
65
  #[derive(Debug, Deserialize)]
66
- pub(crate) struct Config {
67
  pub(crate) whisper: WhisperParams,
68
  pub(crate) server: Server,
69
  }
 
1
  use std::ffi::c_int;
 
2
  use std::net::IpAddr;
3
+ use lazy_static::lazy_static;
4
  use serde::{Deserialize};
5
  use whisper_rs::FullParams;
6
 
7
+ #[derive(Debug)]
8
+ pub enum Error {
9
+ IoError(std::io::Error),
10
+ ConfigError(serde_yaml::Error),
11
+ }
12
+
13
+ pub(crate) fn load_config() -> Result<Config, Error> {
14
+ let config_str = std::fs::read_to_string("config.yaml").map_err(|e| Error::IoError(e))?;
15
+ let config: Config = serde_yaml::from_str(config_str.as_str())
16
+ .map_err(|e| Error::ConfigError(e))?;
17
+ return Ok(config)
18
+ }
19
+
20
+ lazy_static! {
21
+ pub static ref CONFIG: Config = load_config().expect("failed to load config");
22
+ }
23
+
24
+ #[derive(Debug, Deserialize, Clone)]
25
  pub(crate) struct WhisperParams {
26
  pub(crate) n_threads: Option<usize>,
27
+ // pub(crate) step_ms: u32,
28
+ // pub(crate) length_ms: u32,
29
+ pub(crate) keep_ms: u32,
30
+ pub(crate) max_tokens: u32,
31
+ pub(crate) audio_ctx: u32,
32
+ // pub(crate) vad_thold: f32,
33
+ // pub(crate) freq_thold: f32,
 
34
  pub(crate) speed_up: bool,
35
  pub(crate) translate: bool,
36
  pub(crate) no_fallback: bool,
37
  pub(crate) print_special: bool,
38
  pub(crate) no_context: bool,
39
  pub(crate) no_timestamps: bool,
40
+ // pub(crate) tinydiarize: bool,
41
  pub(crate) language: Option<String>,
42
  pub(crate) model: String,
43
  }
 
45
  const NONE: [c_int;0] = [];
46
 
47
  impl WhisperParams {
48
+ pub(crate) fn to_full_params<'a, 'b>(&'a self, tokens: &'b [c_int]) -> FullParams<'a, 'b> {
49
  let mut param = FullParams::new(Default::default());
50
  param.set_print_progress(false);
51
  param.set_print_special(self.print_special);
52
  param.set_print_realtime(false);
53
  param.set_print_timestamps(!self.no_timestamps);
54
  param.set_translate(self.translate);
55
+ param.set_single_segment(false);
56
+ param.set_max_tokens(self.max_tokens as i32);
57
  let lang = self.language.as_ref().map(|s| s.as_str());
58
  param.set_language(lang);
59
  let num_cpus = std::thread::available_parallelism().map(|c| c.get()).unwrap_or(4);
60
  param.set_n_threads(self.n_threads.unwrap_or(num_cpus) as c_int);
61
+ param.set_audio_ctx(self.audio_ctx as i32);
62
  param.set_speed_up(self.speed_up);
63
  // param.set_tdrz_enable(self.tinydiarize);
64
  if self.no_fallback {
 
66
  }
67
  if self.no_context {
68
  param.set_tokens(&NONE);
69
+ } else {
70
+ param.set_tokens(&tokens);
71
  }
72
 
73
  param
 
81
  }
82
 
83
  #[derive(Debug, Deserialize)]
84
+ pub struct Config {
85
  pub(crate) whisper: WhisperParams,
86
  pub(crate) server: Server,
87
  }
src/group.rs ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::time::Duration;
2
+ use tokio::{select};
3
+ use tokio::time::sleep;
4
+ use tokio::sync::mpsc::{Receiver, channel};
5
+ use tokio::sync::mpsc::error::TryRecvError;
6
+
7
+ pub struct GroupedWithin<I>
8
+ where I: 'static + Send {
9
+ outlet: Receiver<Vec<I>>
10
+ }
11
+
12
+ impl<I> GroupedWithin<I>
13
+ where I: 'static + Send {
14
+ pub fn new(group_size: usize,
15
+ window_time: Duration,
16
+ mut inlet: Receiver<Vec<I>>,
17
+ buffer: usize) -> Self {
18
+ let (tx, outlet) = channel::<Vec<I>>(buffer);
19
+ tokio::spawn(async move {
20
+ let mut window = Vec::with_capacity(group_size);
21
+
22
+ loop {
23
+ let grouped_fut = async {
24
+ while let Some(c) = inlet.recv().await {
25
+ window.extend(c);
26
+ if window.len() > group_size {
27
+ let will_send: Vec<I> = window.drain(0..group_size).collect();
28
+ return Some(will_send)
29
+ }
30
+ }
31
+ return None
32
+ };
33
+
34
+ let grouped = select! {
35
+ _ = sleep(window_time) => {
36
+ window.drain(..).collect()
37
+ },
38
+ grouped_opt = grouped_fut => {
39
+ match grouped_opt {
40
+ None => break,
41
+ Some(grouped) => grouped
42
+ }
43
+ }
44
+ };
45
+
46
+ if let Err(e) = tx.send(grouped).await {
47
+ tracing::error!("{}", e);
48
+ }
49
+ }
50
+ });
51
+ Self {
52
+ outlet
53
+ }
54
+ }
55
+
56
+ pub fn next(&mut self) -> Result<Vec<I>, TryRecvError> {
57
+ self.outlet.try_recv()
58
+ }
59
+ }
src/main.rs CHANGED
@@ -17,17 +17,17 @@ use poem::web::websocket::{Message, WebSocket};
17
  use futures_util::stream::StreamExt;
18
  use poem::web::{Data, Query};
19
 
20
- use tokio::{fs, select};
21
  use serde::{Deserialize, Serialize};
22
- use whisper_rs::WhisperContext;
23
  use lesson::{LessonsManager};
24
- use crate::config::Config;
25
  use crate::lesson::Viseme;
26
- use crate::whisper::run_whisper;
27
 
28
  mod lesson;
29
  mod config;
30
  mod whisper;
 
31
 
32
 
33
  #[derive(Debug, Parser)]
@@ -50,26 +50,11 @@ struct Context {
50
  lessons_manager: LessonsManager,
51
  }
52
 
53
- #[derive(Debug)]
54
- enum Error {
55
- IoError(std::io::Error),
56
- ConfigError(serde_yaml::Error),
57
- }
58
-
59
- async fn load_config() -> Result<Config, Error> {
60
- let config_str = fs::read_to_string("config.yaml").await.map_err(|e| Error::IoError(e))?;
61
- let config: Config = serde_yaml::from_str(config_str.as_str())
62
- .map_err(|e| Error::ConfigError(e))?;
63
- return Ok(config)
64
- }
65
 
66
  #[tokio::main]
67
  async fn main() -> Result<(), std::io::Error> {
68
  tracing_subscriber::fmt::init();
69
 
70
- let config = load_config().await.expect("failed to load config");
71
- run_whisper(&config).await;
72
-
73
  let Opt {
74
  region,
75
  verbose,
@@ -107,7 +92,8 @@ async fn main() -> Result<(), std::io::Error> {
107
  .at("lesson-speaker", StaticFileEndpoint::new("./static/index.html"))
108
  .at("lesson-listener", StaticFileEndpoint::new("./static/index.html"))
109
  .data(ctx);
110
- let listener = TcpListener::bind("[::]:8080");
 
111
  let server = Server::new(listener);
112
 
113
  server.run(app).await
@@ -127,11 +113,20 @@ async fn stream_speaker(ctx: Data<&Context>, query: Query<LessonSpeakerQuery>, w
127
  ws.on_upgrade(|mut socket| async move {
128
  let origin_tx = lesson.voice_channel();
129
  let mut transcribe_rx = lesson.transcript_channel();
 
 
130
  loop {
131
  select! {
 
 
 
 
 
 
132
  msg = socket.next() => {
133
  match msg.as_ref() {
134
  Some(Ok(Message::Binary(bin))) => {
 
135
  if origin_tx.send(bin.to_vec()).await.is_err() {
136
  println!("tx closed");
137
  break;
@@ -217,7 +212,6 @@ async fn stream_listener(ctx: Data<&Context>, query: Query<LessonListenerQuery>,
217
  },
218
  voice = voice_rx.recv() => {
219
  if let Ok(voice) = voice {
220
- println!("Synthesized: {:?}", voice.len());
221
  let _ = socket.send(Message::Binary(voice)).await;
222
  }
223
  },
@@ -225,7 +219,6 @@ async fn stream_listener(ctx: Data<&Context>, query: Query<LessonListenerQuery>,
225
  if let Ok(visemes) = visemes {
226
  let evt = LiveLessonTextEvent::LipSync { visemes };
227
  let json = serde_json::to_string(&evt).expect("failed to serialize");
228
- println!("Visemes: {:?}", json);
229
  let _ = socket.send(Message::Text(json)).await;
230
  }
231
  },
 
17
  use futures_util::stream::StreamExt;
18
  use poem::web::{Data, Query};
19
 
20
+ use tokio::{select};
21
  use serde::{Deserialize, Serialize};
 
22
  use lesson::{LessonsManager};
23
+ use crate::config::CONFIG;
24
  use crate::lesson::Viseme;
25
+ use crate::whisper::WhisperHandler;
26
 
27
  mod lesson;
28
  mod config;
29
  mod whisper;
30
+ mod group;
31
 
32
 
33
  #[derive(Debug, Parser)]
 
50
  lessons_manager: LessonsManager,
51
  }
52
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  #[tokio::main]
55
  async fn main() -> Result<(), std::io::Error> {
56
  tracing_subscriber::fmt::init();
57
 
 
 
 
58
  let Opt {
59
  region,
60
  verbose,
 
92
  .at("lesson-speaker", StaticFileEndpoint::new("./static/index.html"))
93
  .at("lesson-listener", StaticFileEndpoint::new("./static/index.html"))
94
  .data(ctx);
95
+ let addr = format!("{}:{}", CONFIG.server.host, CONFIG.server.port);
96
+ let listener = TcpListener::bind(addr);
97
  let server = Server::new(listener);
98
 
99
  server.run(app).await
 
113
  ws.on_upgrade(|mut socket| async move {
114
  let origin_tx = lesson.voice_channel();
115
  let mut transcribe_rx = lesson.transcript_channel();
116
+ let whisper = WhisperHandler::new(CONFIG.whisper.clone()).expect("failed to create whisper");
117
+ let mut whisper_transcribe_rx = whisper.subscribe();
118
  loop {
119
  select! {
120
+ w = whisper_transcribe_rx.recv() => {
121
+ let Ok(txt) = w else {
122
+ continue
123
+ };
124
+ println!("Whisper: {:?}", txt)
125
+ }
126
  msg = socket.next() => {
127
  match msg.as_ref() {
128
  Some(Ok(Message::Binary(bin))) => {
129
+ let _ = whisper.send(bin.clone()).await; // whisper test
130
  if origin_tx.send(bin.to_vec()).await.is_err() {
131
  println!("tx closed");
132
  break;
 
212
  },
213
  voice = voice_rx.recv() => {
214
  if let Ok(voice) = voice {
 
215
  let _ = socket.send(Message::Binary(voice)).await;
216
  }
217
  },
 
219
  if let Ok(visemes) = visemes {
220
  let evt = LiveLessonTextEvent::LipSync { visemes };
221
  let json = serde_json::to_string(&evt).expect("failed to serialize");
 
222
  let _ = socket.send(Message::Text(json)).await;
223
  }
224
  },
src/whisper.rs CHANGED
@@ -1,14 +1,57 @@
1
- use whisper_rs::WhisperContext;
2
- use crate::config::Config;
 
 
 
 
 
 
 
 
3
 
4
- pub(crate) async fn run_whisper(config: &Config) {
5
- let ctx = WhisperContext::new(&*config.whisper.model).expect("failed to load whisper context");
6
- let mut _state = ctx.create_state().expect("failed to create state");
7
- let params = (&config.whisper).to_full_params();
8
- _state.full(params, &[]).expect("TODO: panic message");
9
  }
10
 
11
- async fn pcm_i16_to_f32(input: &Vec<u8>) -> Vec<f32> {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  let pcm_i16 = input
13
  .chunks_exact(2)
14
  .map(|chunk| {
@@ -17,13 +60,154 @@ async fn pcm_i16_to_f32(input: &Vec<u8>) -> Vec<f32> {
17
  i16::from_le_bytes(buf)
18
  })
19
  .collect::<Vec<i16>>();
20
- let pcm_f32 = pcm_i16
21
- .iter()
22
- .map(|i| *i as f32 / i16::MAX as f32)
23
- .collect::<Vec<f32>>();
24
- pcm_f32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  }
26
 
27
- struct WhisperHandler {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::ffi::c_int;
2
+ use std::fmt::{Debug, Display, Formatter};
3
+ use std::thread::sleep;
4
+ use std::time::Duration;
5
+ use lazy_static::lazy_static;
6
+ use tokio::sync::{broadcast, mpsc, oneshot};
7
+ use whisper_rs::{convert_integer_to_float_audio, WhisperContext, WhisperState};
8
+ use whisper_rs_sys::WHISPER_SAMPLE_RATE;
9
+ use crate::config::{WhisperParams, CONFIG};
10
+ use crate::group::GroupedWithin;
11
 
12
+ lazy_static! {
13
+ static ref WHISPER_CONTEXT: WhisperContext = {
14
+ WhisperContext::new(&*CONFIG.whisper.model)
15
+ .expect("failed to create WhisperContext")
16
+ };
17
  }
18
 
19
+ #[derive(Debug)]
20
+ pub(crate) enum Error {
21
+ WhisperError {
22
+ description: String,
23
+ error: whisper_rs::WhisperError,
24
+ },
25
+ }
26
+
27
+ impl Error {
28
+ fn whisper_error(description: &str, error: whisper_rs::WhisperError) -> Self {
29
+ Self::WhisperError {
30
+ description: description.to_string(),
31
+ error,
32
+ }
33
+ }
34
+ }
35
+
36
+ impl Display for Error {
37
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
38
+ match self {
39
+ Self::WhisperError { description, error } => {
40
+ write!(f, "WhisperError: {}: {}", description, error)
41
+ }
42
+ }
43
+ }
44
+ }
45
+
46
+ impl std::error::Error for Error {
47
+ fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
48
+ match self {
49
+ Self::WhisperError { error, .. } => Some(error),
50
+ }
51
+ }
52
+ }
53
+
54
+ fn pcm_i16_to_f32(input: &Vec<u8>) -> Vec<f32> {
55
  let pcm_i16 = input
56
  .chunks_exact(2)
57
  .map(|chunk| {
 
60
  i16::from_le_bytes(buf)
61
  })
62
  .collect::<Vec<i16>>();
63
+ convert_integer_to_float_audio(pcm_i16.as_slice())
64
+ }
65
+
66
+ #[derive(Clone, Debug)]
67
+ pub struct Segment {
68
+ pub start_timestamp: i64,
69
+ pub end_timestamp: i64,
70
+ pub text: String,
71
+ }
72
+
73
+ pub struct WhisperHandler {
74
+ tx: mpsc::Sender<Vec<u8>>,
75
+ transcription_tx: broadcast::Sender<Vec<Segment>>,
76
+ stop_handle: Option<oneshot::Sender<()>>,
77
+ }
78
+
79
+ impl WhisperHandler {
80
+ pub(crate) fn new(config: WhisperParams) -> Result<Self, Error> {
81
+ let n_samples_keep: usize = (config.keep_ms * WHISPER_SAMPLE_RATE / 1000) as usize;
82
+ let (stop_handle, mut stop_signal) = oneshot::channel();
83
+ let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<u8>>(128);
84
+ let mut grouped = GroupedWithin::new(
85
+ n_samples_keep,
86
+ Duration::from_secs(5),
87
+ pcm_rx,
88
+ 1024
89
+ );
90
+ let (transcription_tx, _) = broadcast::channel::<Vec<Segment>>(128);
91
+ let shared_transcription_tx = transcription_tx.clone();
92
+ let mut state = WHISPER_CONTEXT.create_state()
93
+ .map_err(|e| Error::whisper_error("failed to create WhisperState", e))?;
94
+ tokio::task::spawn_blocking(move || {
95
+ let mut tokens: Vec<c_int> = Default::default();
96
+ let mut pcm_f32: Vec<f32> = Default::default();
97
+ while let Err(oneshot::error::TryRecvError::Empty) = stop_signal.try_recv() {
98
+ let new_pcm_f32 = match grouped.next() {
99
+ Err(mpsc::error::TryRecvError::Disconnected) => break,
100
+ Err(mpsc::error::TryRecvError::Empty) => {
101
+ sleep(Duration::from_millis(10));
102
+ continue
103
+ }
104
+ Ok(data) => {
105
+ pcm_i16_to_f32(&data)
106
+ }
107
+ };
108
+
109
+ pcm_f32.extend(new_pcm_f32);
110
+ match inference(&mut state, &config, n_samples_keep, &mut tokens, &mut pcm_f32) {
111
+ Ok(segments) => {
112
+ if segments.is_empty() {
113
+ continue
114
+ }
115
+ if let Err(e) = shared_transcription_tx.send(segments) {
116
+ tracing::error!("failed to send transcription: {}", e);
117
+ break
118
+ }
119
+ }
120
+ Err(err) => {
121
+ tracing::error!("failed to run whisper: {}", err);
122
+ continue
123
+ // break
124
+ }
125
+ }
126
+ }
127
+ });
128
+ Ok(Self {
129
+ tx: pcm_tx,
130
+ transcription_tx,
131
+ stop_handle: Some(stop_handle),
132
+ })
133
+ }
134
+
135
+ pub fn subscribe(&self) -> broadcast::Receiver<Vec<Segment>> {
136
+ self.transcription_tx.subscribe()
137
+ }
138
+
139
+ pub async fn send(&self, data: Vec<u8>) -> Result<(), mpsc::error::SendError<Vec<u8>>> {
140
+ self.tx.send(data).await
141
+ }
142
  }
143
 
144
+ fn inference(
145
+ state: &mut WhisperState,
146
+ config: &WhisperParams,
147
+ n_samples_keep: usize,
148
+ prompt_tokens: &mut Vec<c_int>,
149
+ pcm_f32: &mut Vec<f32>
150
+ ) -> Result<Vec<Segment>, Error> {
151
+ let params = config.to_full_params(prompt_tokens.as_slice());
152
+
153
+ let st = std::time::Instant::now();
154
+ let _ = state.full(params, pcm_f32.as_slice())
155
+ .map_err(|e| Error::whisper_error("failed to initialize WhisperState", e))?;
156
+ let et = std::time::Instant::now();
157
+
158
+ let num_segments = state
159
+ .full_n_segments()
160
+ .map_err(|e| Error::whisper_error("failed to get number of segments", e))?;
161
+ let mut segments: Vec<Segment> = Vec::with_capacity(num_segments as usize);
162
+ for i in 0..num_segments {
163
+ let segment = state
164
+ .full_get_segment_text(i)
165
+ .map_err(|e| Error::whisper_error("failed to get segment", e))?;
166
+ let start_timestamp = state
167
+ .full_get_segment_t0(i)
168
+ .map_err(|e| Error::whisper_error("failed to get start timestamp", e))?;
169
+ let end_timestamp = state
170
+ .full_get_segment_t1(i)
171
+ .map_err(|e| Error::whisper_error("failed to get end timestamp", e))?;
172
+ tracing::debug!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
173
+ segments.push(Segment { start_timestamp, end_timestamp, text: segment });
174
+ }
175
 
176
+ if !config.no_context {
177
+ prompt_tokens.clear();
178
+
179
+ // keep the last n_samples_keep samples from pcm_f32
180
+ if pcm_f32.len() > n_samples_keep {
181
+ let _ = pcm_f32.drain(0..(pcm_f32.len() - n_samples_keep)).collect::<Vec<_>>();
182
+ }
183
+
184
+ let n_segments = state
185
+ .full_n_segments()
186
+ .map_err(|e| Error::whisper_error("failed to get number of segments", e))?;
187
+ for i in 0..n_segments {
188
+ let token_count = state
189
+ .full_n_tokens(i)
190
+ .map_err(|e| Error::whisper_error("failed to get number of tokens", e))?;
191
+ for j in 0..token_count {
192
+ let token = state
193
+ .full_get_token_id(i, j)
194
+ .map_err(|e| Error::whisper_error("failed to get token", e))?;
195
+ prompt_tokens.push(token);
196
+ }
197
+ }
198
+ }
199
+
200
+ tracing::trace!("took {}ms", (et - st).as_millis());
201
+ Ok(segments)
202
+ }
203
+
204
+ impl Drop for WhisperHandler {
205
+ fn drop(&mut self) {
206
+ let Some(stop_handle) = self.stop_handle.take() else {
207
+ return tracing::warn!("WhisperHandler::drop() called without stop_handle");
208
+ };
209
+ if let Err(_) = stop_handle.send(()) {
210
+ tracing::warn!("WhisperHandler::drop() failed to send stop signal");
211
+ }
212
+ }
213
+ }