Matrix committed on
Commit
3569cbd
·
1 Parent(s): 9ce120f

chore: use once_cell lazy instead of lazy_static

Browse files
Files changed (7) hide show
  1. Cargo.lock +1 -1
  2. Cargo.toml +1 -1
  3. src/config.rs +6 -20
  4. src/group.rs +1 -1
  5. src/lesson.rs +3 -3
  6. src/main.rs +16 -14
  7. src/whisper.rs +29 -25
Cargo.lock CHANGED
@@ -1449,7 +1449,7 @@ dependencies = [
1449
  "aws-sdk-translate",
1450
  "config",
1451
  "futures-util",
1452
- "lazy_static",
1453
  "poem",
1454
  "serde",
1455
  "serde_json",
 
1449
  "aws-sdk-translate",
1450
  "config",
1451
  "futures-util",
1452
+ "once_cell",
1453
  "poem",
1454
  "serde",
1455
  "serde_json",
Cargo.toml CHANGED
@@ -12,7 +12,7 @@ aws-sdk-translate = "0.34"
12
  aws-sdk-polly = "0.34"
13
  config = "0.13"
14
  futures-util = "0.3"
15
- lazy_static = "1.4"
16
  serde = { version = "1.0", features = ["derive"] }
17
  serde_json = "1.0"
18
  serde_yaml = "0.9"
 
12
  aws-sdk-polly = "0.34"
13
  config = "0.13"
14
  futures-util = "0.3"
15
+ once_cell = "1.18"
16
  serde = { version = "1.0", features = ["derive"] }
17
  serde_json = "1.0"
18
  serde_yaml = "0.9"
src/config.rs CHANGED
@@ -1,26 +1,12 @@
1
- use std::{ffi::c_int, net::IpAddr, env};
2
 
3
- use config::{Environment, Config, File};
4
- use lazy_static::lazy_static;
5
  use serde::Deserialize;
6
  use whisper_rs::FullParams;
7
 
8
- #[derive(Debug)]
9
- pub enum Error {
10
- IoError(std::io::Error),
11
- ConfigError(serde_yaml::Error),
12
- }
13
-
14
- pub(crate) fn load_config() -> Result<Settings, Error> {
15
- let config_str = std::fs::read_to_string("../config/dev.yaml").map_err(|e| Error::IoError(e))?;
16
- let config: Settings =
17
- serde_yaml::from_str(config_str.as_str()).map_err(|e| Error::ConfigError(e))?;
18
- return Ok(config);
19
- }
20
-
21
- lazy_static! {
22
- pub static ref CONFIG: Settings = load_config().expect("failed to load config");
23
- }
24
 
25
  #[derive(Debug, Deserialize, Clone)]
26
  pub(crate) struct WhisperConfig {
@@ -62,7 +48,7 @@ impl WhisperParams {
62
  param.set_translate(self.translate);
63
  param.set_single_segment(false);
64
  param.set_max_tokens(self.max_tokens as i32);
65
- let lang = self.language.as_ref().map(|s| s.as_str());
66
  param.set_language(lang);
67
  let num_cpus = std::thread::available_parallelism()
68
  .map(|c| c.get())
 
1
+ use std::{env, ffi::c_int, net::IpAddr};
2
 
3
+ use config::{Config, Environment, File};
4
+ use once_cell::sync::Lazy;
5
  use serde::Deserialize;
6
  use whisper_rs::FullParams;
7
 
8
+ pub(crate) static SETTINGS: Lazy<Settings> =
9
+ Lazy::new(|| Settings::new().expect("Failed to initialize settings"));
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  #[derive(Debug, Deserialize, Clone)]
12
  pub(crate) struct WhisperConfig {
 
48
  param.set_translate(self.translate);
49
  param.set_single_segment(false);
50
  param.set_max_tokens(self.max_tokens as i32);
51
+ let lang = self.language.as_deref();
52
  param.set_language(lang);
53
  let num_cpus = std::thread::available_parallelism()
54
  .map(|c| c.get())
src/group.rs CHANGED
@@ -35,7 +35,7 @@ where
35
  return Some(will_send);
36
  }
37
  }
38
- return None;
39
  };
40
 
41
  let grouped: Vec<I> = select! {
 
35
  return Some(will_send);
36
  }
37
  }
38
+ None
39
  };
40
 
41
  let grouped: Vec<I> = select! {
src/lesson.rs CHANGED
@@ -29,9 +29,9 @@ pub struct LessonsManager {
29
 
30
  impl LessonsManager {
31
  pub(crate) fn new(sdk_config: &SdkConfig) -> Self {
32
- let transcript_client = aws_sdk_transcribestreaming::Client::new(&sdk_config);
33
- let translate_client = aws_sdk_translate::Client::new(&sdk_config);
34
- let polly_client = aws_sdk_polly::Client::new(&sdk_config);
35
  LessonsManager {
36
  translate_client,
37
  polly_client,
 
29
 
30
  impl LessonsManager {
31
  pub(crate) fn new(sdk_config: &SdkConfig) -> Self {
32
+ let transcript_client = aws_sdk_transcribestreaming::Client::new(sdk_config);
33
+ let translate_client = aws_sdk_translate::Client::new(sdk_config);
34
+ let polly_client = aws_sdk_polly::Client::new(sdk_config);
35
  LessonsManager {
36
  translate_client,
37
  polly_client,
src/main.rs CHANGED
@@ -5,21 +5,23 @@
5
 
6
  #![allow(clippy::result_large_err)]
7
 
8
- use aws_sdk_transcribestreaming::{meta::PKG_VERSION};
9
- use futures_util::stream::StreamExt;
10
- use futures_util::SinkExt;
11
- use poem::endpoint::{StaticFileEndpoint, StaticFilesEndpoint};
12
- use poem::web::websocket::{Message, WebSocket};
13
- use poem::web::{Data, Query};
14
- use poem::{get, handler, listener::TcpListener, EndpointExt, IntoResponse, Route, Server};
15
-
16
- use crate::config::CONFIG;
17
- use crate::lesson::Viseme;
18
- use crate::whisper::WhisperHandler;
19
- use lesson::LessonsManager;
20
  use serde::{Deserialize, Serialize};
21
  use tokio::select;
22
 
 
 
23
  mod config;
24
  mod group;
25
  mod lesson;
@@ -61,7 +63,7 @@ async fn main() -> Result<(), std::io::Error> {
61
  StaticFileEndpoint::new("./static/index.html"),
62
  )
63
  .data(ctx);
64
- let addr = format!("{}:{}", CONFIG.server.host, CONFIG.server.port);
65
  let listener = TcpListener::bind(addr);
66
  let server = Server::new(listener);
67
 
@@ -94,7 +96,7 @@ async fn stream_speaker(
94
  let _origin_tx = lesson.voice_channel();
95
  let mut transcribe_rx = lesson.transcript_channel();
96
  let whisper =
97
- WhisperHandler::new(CONFIG.whisper.clone(), prompt).expect("failed to create whisper");
98
  let mut whisper_transcribe_rx = whisper.subscribe();
99
  loop {
100
  select! {
 
5
 
6
  #![allow(clippy::result_large_err)]
7
 
8
+ use aws_sdk_transcribestreaming::meta::PKG_VERSION;
9
+ use futures_util::{stream::StreamExt, SinkExt};
10
+ use poem::{
11
+ endpoint::{StaticFileEndpoint, StaticFilesEndpoint},
12
+ get, handler,
13
+ listener::TcpListener,
14
+ web::{
15
+ websocket::{Message, WebSocket},
16
+ Data, Query,
17
+ },
18
+ EndpointExt, IntoResponse, Route, Server,
19
+ };
20
  use serde::{Deserialize, Serialize};
21
  use tokio::select;
22
 
23
+ use crate::{config::*, lesson::*, whisper::*};
24
+
25
  mod config;
26
  mod group;
27
  mod lesson;
 
63
  StaticFileEndpoint::new("./static/index.html"),
64
  )
65
  .data(ctx);
66
+ let addr = format!("{}:{}", SETTINGS.server.host, SETTINGS.server.port);
67
  let listener = TcpListener::bind(addr);
68
  let server = Server::new(listener);
69
 
 
96
  let _origin_tx = lesson.voice_channel();
97
  let mut transcribe_rx = lesson.transcript_channel();
98
  let whisper =
99
+ WhisperHandler::new(SETTINGS.whisper.clone(), prompt).expect("failed to create whisper");
100
  let mut whisper_transcribe_rx = whisper.subscribe();
101
  loop {
102
  select! {
src/whisper.rs CHANGED
@@ -1,19 +1,23 @@
1
- use crate::config::{WhisperConfig, CONFIG};
2
- use crate::group::GroupedWithin;
3
- use lazy_static::lazy_static;
4
- use std::collections::VecDeque;
5
- use std::ffi::c_int;
6
- use std::fmt::{Debug, Display, Formatter};
7
- use std::thread::sleep;
8
- use std::time::Duration;
 
9
  use tokio::sync::{broadcast, mpsc, oneshot};
10
  use whisper_rs::{convert_integer_to_float_audio, WhisperContext, WhisperState, WhisperToken};
11
  use whisper_rs_sys::WHISPER_SAMPLE_RATE;
12
 
13
- lazy_static! {
14
- static ref WHISPER_CONTEXT: WhisperContext =
15
- WhisperContext::new(&*CONFIG.whisper.model).expect("failed to create WhisperContext");
16
- }
 
 
 
17
 
18
  #[derive(Debug)]
19
  pub(crate) enum Error {
@@ -86,10 +90,10 @@ impl WhisperHandler {
86
  .create_state()
87
  .map_err(|e| Error::whisper_error("failed to create WhisperState", e))?;
88
  let preset_prompt_tokens = WHISPER_CONTEXT
89
- .tokenize(prompt.as_str(), CONFIG.whisper.max_prompt_tokens)
90
  .map_err(|e| Error::whisper_error("failed to tokenize prompt", e))?;
91
  tokio::task::spawn_blocking(move || {
92
- let mut detector = Detector::new(state, &CONFIG.whisper, preset_prompt_tokens);
93
  let mut grouped = GroupedWithin::new(
94
  detector.n_samples_step * 2,
95
  Duration::from_millis(config.step_ms as u64),
@@ -203,7 +207,7 @@ impl Detector {
203
  self.preset_prompt_tokens.as_slice(),
204
  self.prompt_tokens.as_slice(),
205
  ]
206
- .concat();
207
  let params = self.config.params.to_full_params(prompt_tokens.as_slice());
208
  let start = std::time::Instant::now();
209
  let _ = self
@@ -229,18 +233,18 @@ impl Detector {
229
  for i in 0..num_segments {
230
  let end_timestamp: i64 = timestamp_offset
231
  + 10 * self
232
- .state
233
- .full_get_segment_t1(i)
234
- .map_err(|e| Error::whisper_error("failed to get end timestamp", e))?;
235
  if end_timestamp <= stable_offset {
236
  continue;
237
  }
238
 
239
  let start_timestamp: i64 = timestamp_offset
240
  + 10 * self
241
- .state
242
- .full_get_segment_t0(i)
243
- .map_err(|e| Error::whisper_error("failed to get start timestamp", e))?;
244
  let segment = self
245
  .state
246
  .full_get_segment_text(i)
@@ -285,12 +289,12 @@ impl Detector {
285
  let Some(last) = stable_segments.last() else {
286
  return;
287
  };
288
- let drop_offset: usize = (last.end_timestamp as usize / 1000 * WHISPER_SAMPLE_RATE as usize
289
- - self.offset) as usize;
290
  let len_to_drain = self.pcm_f32.drain(0..drop_offset).len();
291
  self.offset += len_to_drain;
292
 
293
- for segment in stable_segments.into_iter() {
294
  self.prompt_tokens.extend(&segment.tokens);
295
  }
296
  if self.prompt_tokens.len() > self.config.max_prompt_tokens {
@@ -307,7 +311,7 @@ impl Drop for WhisperHandler {
307
  let Some(stop_handle) = self.stop_handle.take() else {
308
  return tracing::warn!("WhisperHandler::drop() called without stop_handle");
309
  };
310
- if let Err(_) = stop_handle.send(()) {
311
  tracing::warn!("WhisperHandler::drop() failed to send stop signal");
312
  }
313
  }
 
1
+ use std::{
2
+ collections::VecDeque,
3
+ ffi::c_int,
4
+ fmt::{Debug, Display, Formatter},
5
+ thread::sleep,
6
+ time::Duration,
7
+ };
8
+
9
+ use once_cell::sync::Lazy;
10
  use tokio::sync::{broadcast, mpsc, oneshot};
11
  use whisper_rs::{convert_integer_to_float_audio, WhisperContext, WhisperState, WhisperToken};
12
  use whisper_rs_sys::WHISPER_SAMPLE_RATE;
13
 
14
+ use crate::config::{Settings, SETTINGS};
15
+ use crate::{config::WhisperConfig, group::GroupedWithin};
16
+
17
+ static WHISPER_CONTEXT: Lazy<WhisperContext> = Lazy::new(|| {
18
+ let settings = Settings::new().expect("Failed to initialize settings.");
19
+ WhisperContext::new(&settings.whisper.model).expect("failed to create WhisperContext")
20
+ });
21
 
22
  #[derive(Debug)]
23
  pub(crate) enum Error {
 
90
  .create_state()
91
  .map_err(|e| Error::whisper_error("failed to create WhisperState", e))?;
92
  let preset_prompt_tokens = WHISPER_CONTEXT
93
+ .tokenize(prompt.as_str(), SETTINGS.whisper.max_prompt_tokens)
94
  .map_err(|e| Error::whisper_error("failed to tokenize prompt", e))?;
95
  tokio::task::spawn_blocking(move || {
96
+ let mut detector = Detector::new(state, &SETTINGS.whisper, preset_prompt_tokens);
97
  let mut grouped = GroupedWithin::new(
98
  detector.n_samples_step * 2,
99
  Duration::from_millis(config.step_ms as u64),
 
207
  self.preset_prompt_tokens.as_slice(),
208
  self.prompt_tokens.as_slice(),
209
  ]
210
+ .concat();
211
  let params = self.config.params.to_full_params(prompt_tokens.as_slice());
212
  let start = std::time::Instant::now();
213
  let _ = self
 
233
  for i in 0..num_segments {
234
  let end_timestamp: i64 = timestamp_offset
235
  + 10 * self
236
+ .state
237
+ .full_get_segment_t1(i)
238
+ .map_err(|e| Error::whisper_error("failed to get end timestamp", e))?;
239
  if end_timestamp <= stable_offset {
240
  continue;
241
  }
242
 
243
  let start_timestamp: i64 = timestamp_offset
244
  + 10 * self
245
+ .state
246
+ .full_get_segment_t0(i)
247
+ .map_err(|e| Error::whisper_error("failed to get start timestamp", e))?;
248
  let segment = self
249
  .state
250
  .full_get_segment_text(i)
 
289
  let Some(last) = stable_segments.last() else {
290
  return;
291
  };
292
+ let drop_offset: usize =
293
+ last.end_timestamp as usize / 1000 * WHISPER_SAMPLE_RATE as usize - self.offset;
294
  let len_to_drain = self.pcm_f32.drain(0..drop_offset).len();
295
  self.offset += len_to_drain;
296
 
297
+ for segment in stable_segments.iter() {
298
  self.prompt_tokens.extend(&segment.tokens);
299
  }
300
  if self.prompt_tokens.len() > self.config.max_prompt_tokens {
 
311
  let Some(stop_handle) = self.stop_handle.take() else {
312
  return tracing::warn!("WhisperHandler::drop() called without stop_handle");
313
  };
314
+ if stop_handle.send(()).is_err() {
315
  tracing::warn!("WhisperHandler::drop() failed to send stop signal");
316
  }
317
  }