mingyang91 commited on
Commit
ee7230e
·
verified ·
1 Parent(s): b0c7520

prepare to integrate whisper

Browse files
Files changed (7) hide show
  1. .gitignore +1 -0
  2. Cargo.lock +20 -0
  3. Cargo.toml +1 -0
  4. config.yaml +22 -0
  5. src/config.rs +76 -0
  6. src/main.rs +21 -4
  7. src/whisper.rs +29 -0
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  /target
2
  .idea/
 
 
1
  /target
2
  .idea/
3
+ models
Cargo.lock CHANGED
@@ -1422,6 +1422,7 @@ dependencies = [
1422
  "poem",
1423
  "serde",
1424
  "serde_json",
 
1425
  "tokio",
1426
  "tokio-stream",
1427
  "tracing-subscriber",
@@ -1740,6 +1741,19 @@ dependencies = [
1740
  "serde",
1741
  ]
1742
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1743
  [[package]]
1744
  name = "sha1"
1745
  version = "0.10.6"
@@ -2184,6 +2198,12 @@ dependencies = [
2184
  "tinyvec",
2185
  ]
2186
 
 
 
 
 
 
 
2187
  [[package]]
2188
  name = "untrusted"
2189
  version = "0.7.1"
 
1422
  "poem",
1423
  "serde",
1424
  "serde_json",
1425
+ "serde_yaml",
1426
  "tokio",
1427
  "tokio-stream",
1428
  "tracing-subscriber",
 
1741
  "serde",
1742
  ]
1743
 
1744
+ [[package]]
1745
+ name = "serde_yaml"
1746
+ version = "0.9.25"
1747
+ source = "registry+https://github.com/rust-lang/crates.io-index"
1748
+ checksum = "1a49e178e4452f45cb61d0cd8cebc1b0fafd3e41929e996cef79aa3aca91f574"
1749
+ dependencies = [
1750
+ "indexmap 2.0.2",
1751
+ "itoa",
1752
+ "ryu",
1753
+ "serde",
1754
+ "unsafe-libyaml",
1755
+ ]
1756
+
1757
  [[package]]
1758
  name = "sha1"
1759
  version = "0.10.6"
 
2198
  "tinyvec",
2199
  ]
2200
 
2201
+ [[package]]
2202
+ name = "unsafe-libyaml"
2203
+ version = "0.2.9"
2204
+ source = "registry+https://github.com/rust-lang/crates.io-index"
2205
+ checksum = "f28467d3e1d3c6586d8f25fa243f544f5800fec42d97032474e17222c2b75cfa"
2206
+
2207
  [[package]]
2208
  name = "untrusted"
2209
  version = "0.7.1"
Cargo.toml CHANGED
@@ -16,6 +16,7 @@ tracing-subscriber = "0.3.17"
16
  futures-util = "0.3.28"
17
  serde = { version = "1.0.189", features = ["derive"] }
18
  serde_json = { version = "1.0.107", features = [] }
 
19
  whisper-rs = { version = "0.8.0" , features = ["coreml"] }
20
 
21
  [dependencies.poem]
 
16
  futures-util = "0.3.28"
17
  serde = { version = "1.0.189", features = ["derive"] }
18
  serde_json = { version = "1.0.107", features = [] }
19
+ serde_yaml = "0.9.25"
20
  whisper-rs = { version = "0.8.0" , features = ["coreml"] }
21
 
22
  [dependencies.poem]
config.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ server:
2
+ port: 8080
3
+ host: ::1
4
+ whisper:
5
+ n_threads: 4
6
+ step_ms: 500
7
+ length_ms: 5000
8
+ keep_ms: 200
9
+ capture_id: -1
10
+ max_tokens: 32
11
+ audio_ctx: 0
12
+ vad_thold: 0.6
13
+ freq_thold: 100.0
14
+ speed_up: false
15
+ translate: false
16
+ no_fallback: false
17
+ print_special: false
18
+ no_context: true
19
+ no_timestamps: false
20
+ tinydiarize: false
21
+ language: "en"
22
+ model: "models/ggml-base.bin"
src/config.rs ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::ffi::c_int;
2
+ use std::fs;
3
+ use std::net::IpAddr;
4
+ use serde::{Deserialize};
5
+ use whisper_rs::FullParams;
6
+
7
+ #[derive(Debug, Deserialize)]
8
+ pub(crate) struct WhisperParams {
9
+ pub(crate) n_threads: Option<usize>,
10
+ pub(crate) step_ms: i32,
11
+ pub(crate) length_ms: i32,
12
+ pub(crate) keep_ms: i32,
13
+ pub(crate) capture_id: i32,
14
+ pub(crate) max_tokens: i32,
15
+ pub(crate) audio_ctx: i32,
16
+ pub(crate) vad_thold: f32,
17
+ pub(crate) freq_thold: f32,
18
+ pub(crate) speed_up: bool,
19
+ pub(crate) translate: bool,
20
+ pub(crate) no_fallback: bool,
21
+ pub(crate) print_special: bool,
22
+ pub(crate) no_context: bool,
23
+ pub(crate) no_timestamps: bool,
24
+ pub(crate) tinydiarize: bool,
25
+ pub(crate) language: Option<String>,
26
+ pub(crate) model: String,
27
+ }
28
+
29
+ const NONE: [c_int;0] = [];
30
+
31
+ impl WhisperParams {
32
+ pub(crate) fn to_full_params<'a, 'b>(&'a self) -> FullParams<'a, 'b> {
33
+ let mut param = FullParams::new(Default::default());
34
+ param.set_print_progress(false);
35
+ param.set_print_special(self.print_special);
36
+ param.set_print_realtime(false);
37
+ param.set_print_timestamps(!self.no_timestamps);
38
+ param.set_translate(self.translate);
39
+ param.set_single_segment(true);
40
+ param.set_max_tokens(self.max_tokens);
41
+ let lang = self.language.as_ref().map(|s| s.as_str());
42
+ param.set_language(lang);
43
+ let num_cpus = std::thread::available_parallelism().map(|c| c.get()).unwrap_or(4);
44
+ param.set_n_threads(self.n_threads.unwrap_or(num_cpus) as c_int);
45
+ param.set_audio_ctx(self.audio_ctx);
46
+ param.set_speed_up(self.speed_up);
47
+ // param.set_tdrz_enable(self.tinydiarize);
48
+ if self.no_fallback {
49
+ param.set_temperature_inc(0.0);
50
+ }
51
+ if self.no_context {
52
+ param.set_tokens(&NONE);
53
+ }
54
+
55
+ param
56
+ }
57
+ }
58
+
59
+ #[derive(Debug, Deserialize)]
60
+ pub(crate) struct Server {
61
+ pub(crate) port: u16,
62
+ pub(crate) host: IpAddr,
63
+ }
64
+
65
+ #[derive(Debug, Deserialize)]
66
+ pub(crate) struct Config {
67
+ pub(crate) whisper: WhisperParams,
68
+ pub(crate) server: Server,
69
+ }
70
+
71
+ #[tokio::test]
72
+ async fn load() {
73
+ let config_str = fs::read_to_string("config.yaml").expect("failed to read config file");
74
+ let params: Config = serde_yaml::from_str(config_str.as_str()).expect("failed to parse config file");
75
+ println!("{:?}", params);
76
+ }
src/main.rs CHANGED
@@ -17,13 +17,17 @@ use poem::web::websocket::{Message, WebSocket};
17
  use futures_util::stream::StreamExt;
18
  use poem::web::{Data, Query};
19
 
20
- use tokio::select;
21
  use serde::{Deserialize, Serialize};
22
  use whisper_rs::WhisperContext;
23
  use lesson::{LessonsManager};
 
24
  use crate::lesson::Viseme;
 
25
 
26
  mod lesson;
 
 
27
 
28
 
29
  #[derive(Debug, Parser)]
@@ -46,12 +50,25 @@ struct Context {
46
  lessons_manager: LessonsManager,
47
  }
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  #[tokio::main]
50
  async fn main() -> Result<(), std::io::Error> {
51
  tracing_subscriber::fmt::init();
52
- let wc = WhisperContext::new("/Users/famer.me/Downloads/ggml-base.bin");
53
- let wc = wc.expect("failed to load whisper context");
54
- let _ = wc.create_state().expect("failed to create state");
55
 
56
  let Opt {
57
  region,
 
17
  use futures_util::stream::StreamExt;
18
  use poem::web::{Data, Query};
19
 
20
+ use tokio::{fs, select};
21
  use serde::{Deserialize, Serialize};
22
  use whisper_rs::WhisperContext;
23
  use lesson::{LessonsManager};
24
+ use crate::config::Config;
25
  use crate::lesson::Viseme;
26
+ use crate::whisper::run_whisper;
27
 
28
  mod lesson;
29
+ mod config;
30
+ mod whisper;
31
 
32
 
33
  #[derive(Debug, Parser)]
 
50
  lessons_manager: LessonsManager,
51
  }
52
 
53
+ #[derive(Debug)]
54
+ enum Error {
55
+ IoError(std::io::Error),
56
+ ConfigError(serde_yaml::Error),
57
+ }
58
+
59
+ async fn load_config() -> Result<Config, Error> {
60
+ let config_str = fs::read_to_string("config.yaml").await.map_err(|e| Error::IoError(e))?;
61
+ let config: Config = serde_yaml::from_str(config_str.as_str())
62
+ .map_err(|e| Error::ConfigError(e))?;
63
+ return Ok(config)
64
+ }
65
+
66
  #[tokio::main]
67
  async fn main() -> Result<(), std::io::Error> {
68
  tracing_subscriber::fmt::init();
69
+
70
+ let config = load_config().await.expect("failed to load config");
71
+ run_whisper(&config).await;
72
 
73
  let Opt {
74
  region,
src/whisper.rs ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use whisper_rs::WhisperContext;
2
+ use crate::config::Config;
3
+
4
+ pub(crate) async fn run_whisper(config: &Config) {
5
+ let ctx = WhisperContext::new(&*config.whisper.model).expect("failed to load whisper context");
6
+ let mut _state = ctx.create_state().expect("failed to create state");
7
+ let params = (&config.whisper).to_full_params();
8
+ _state.full(params, &[]).expect("TODO: panic message");
9
+ }
10
+
11
+ async fn pcm_i16_to_f32(input: &Vec<u8>) -> Vec<f32> {
12
+ let pcm_i16 = input
13
+ .chunks_exact(2)
14
+ .map(|chunk| {
15
+ let mut buf = [0u8; 2];
16
+ buf.copy_from_slice(chunk);
17
+ i16::from_le_bytes(buf)
18
+ })
19
+ .collect::<Vec<i16>>();
20
+ let pcm_f32 = pcm_i16
21
+ .iter()
22
+ .map(|i| *i as f32 / i16::MAX as f32)
23
+ .collect::<Vec<f32>>();
24
+ pcm_f32
25
+ }
26
+
27
+ struct WhisperHandler {
28
+
29
+ }