File size: 3,086 Bytes
ee7230e
 
a4dee07
ee7230e
 
 
a4dee07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
819a8c6
 
5cb1ef7
 
a4dee07
819a8c6
 
 
 
 
 
 
a4dee07
 
ee7230e
 
 
 
819a8c6
 
ee7230e
819a8c6
 
a4dee07
ee7230e
 
 
 
 
 
a4dee07
ee7230e
819a8c6
ee7230e
819a8c6
ee7230e
 
819a8c6
a4dee07
ee7230e
 
 
 
a4dee07
ee7230e
 
819a8c6
ee7230e
 
 
 
 
 
 
 
 
 
 
 
a4dee07
819a8c6
ee7230e
 
 
5b9ecd0
 
 
 
 
 
 
ee7230e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
use std::ffi::c_int;
use std::net::IpAddr;
use lazy_static::lazy_static;
use serde::{Deserialize};
use whisper_rs::FullParams;

/// Errors that can occur while loading the application configuration.
#[derive(Debug)]
pub enum Error {
    /// The configuration file could not be read.
    IoError(std::io::Error),
    /// The configuration file was read but its contents could not be deserialized
    /// from YAML into a [`Config`].
    ConfigError(serde_yaml::Error),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::IoError(e) => write!(f, "failed to read config file: {e}"),
            Error::ConfigError(e) => write!(f, "failed to parse config file: {e}"),
        }
    }
}

impl std::error::Error for Error {
    /// Exposes the wrapped I/O or YAML error so callers can walk the error chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::IoError(e) => Some(e),
            Error::ConfigError(e) => Some(e),
        }
    }
}

pub(crate) fn load_config() -> Result<Config, Error> {
    let config_str = std::fs::read_to_string("config.yaml").map_err(|e| Error::IoError(e))?;
    let config: Config = serde_yaml::from_str(config_str.as_str())
        .map_err(|e| Error::ConfigError(e))?;
    return Ok(config)
}

lazy_static! {
    /// Process-wide configuration, loaded via `load_config` the first time it is accessed.
    ///
    /// # Panics
    ///
    /// Panics on first access if the configuration cannot be loaded.
    pub static ref CONFIG: Config = load_config()
        .unwrap_or_else(|err| panic!("failed to load config: {err:?}"));
}

/// The `whisper` section of `config.yaml`.
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct WhisperConfig {
    /// Decoding options (see [`WhisperParams`]).
    pub(crate) params: WhisperParams,
    /// Not read in this file; the `_ms` suffix suggests milliseconds — confirm with the consumer.
    pub(crate) step_ms: u32,
    /// Not read in this file; the `_ms` suffix suggests milliseconds — confirm with the consumer.
    pub(crate) length_ms: u32,
    /// Not read in this file; the `_ms` suffix suggests milliseconds — confirm with the consumer.
    pub(crate) keep_ms: u32,
    /// Model identifier; not read in this file (presumably a model file path — verify).
    pub(crate) model: String,
    /// Not read in this file; presumably caps the prompt tokens handed to
    /// `WhisperParams::to_full_params` — verify against the caller.
    pub(crate) max_prompt_tokens: usize,
}

/// Decoding options from the `whisper.params` section of `config.yaml`, converted into
/// whisper-rs parameters by [`WhisperParams::to_full_params`].
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct WhisperParams {
    /// Thread count; when `None`, `to_full_params` derives one from the available CPUs.
    pub(crate) n_threads: Option<usize>,
    /// Passed to `FullParams::set_max_tokens`.
    pub(crate) max_tokens: u32,
    /// Passed to `FullParams::set_audio_ctx`.
    pub(crate) audio_ctx: u32,
    /// Passed to `FullParams::set_speed_up`.
    pub(crate) speed_up: bool,
    /// Passed to `FullParams::set_translate`.
    pub(crate) translate: bool,
    /// The `no_fallback` option; see `to_full_params` for how it is applied.
    pub(crate) no_fallback: bool,
    /// Passed to `FullParams::set_print_special`.
    pub(crate) print_special: bool,
    /// Passed to `FullParams::set_print_realtime`.
    pub(crate) print_realtime: bool,
    /// Passed to `FullParams::set_print_progress`.
    pub(crate) print_progress: bool,
    /// Inverted and passed to `FullParams::set_print_timestamps`.
    pub(crate) no_timestamps: bool,
    /// Temperature increment; see `to_full_params` for how it is applied.
    pub(crate) temperature_inc: f32,
    /// The `single_segment` option; see `to_full_params` for how it is applied.
    pub(crate) single_segment: bool,
    // Disabled option, kept for reference:
    // pub(crate) tinydiarize: bool,
    /// Passed (as `Option<&str>`) to `FullParams::set_language`.
    pub(crate) language: Option<String>,
}

/// A zero-length `c_int` array; presumably meant as an empty prompt-token slice for
/// `WhisperParams::to_full_params`. Not referenced elsewhere in this file.
const NONE: [c_int; 0] = [0; 0];

impl WhisperParams {
    /// Builds whisper-rs [`FullParams`] (default sampling strategy) from this configuration.
    ///
    /// * `tokens` is forwarded via `FullParams::set_tokens` as the decoding prompt;
    ///   pass an empty slice for no prompt. The result borrows `self` (for the
    ///   language string) for `'a` and `tokens` for `'b`.
    /// * Thread count is `n_threads` when set to a non-zero value. Otherwise it is
    ///   the detected available parallelism, or 4 if that cannot be determined.
    /// * When `no_fallback` is set, the temperature increment is forced to `0.0`.
    ///   This is the same mapping whisper.cpp's examples use to disable temperature
    ///   fallback.
    /// * Integer settings that do not fit in a `c_int` are clamped to `c_int::MAX`
    ///   instead of wrapping to a negative value.
    pub(crate) fn to_full_params<'a, 'b>(&'a self, tokens: &'b [c_int]) -> FullParams<'a, 'b> {
        let mut param = FullParams::new(Default::default());
        param.set_print_progress(self.print_progress);
        param.set_print_special(self.print_special);
        param.set_print_realtime(self.print_realtime);
        param.set_print_timestamps(!self.no_timestamps);
        param.set_translate(self.translate);
        param.set_single_segment(self.single_segment);
        param.set_max_tokens(Self::clamp_to_c_int(self.max_tokens));
        param.set_language(self.language.as_deref());

        // An explicit 0 is treated like "unset" rather than requesting zero threads.
        let n_threads = self.n_threads.filter(|&n| n > 0).unwrap_or_else(|| {
            std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4)
        });
        param.set_n_threads(Self::clamp_to_c_int(n_threads));

        param.set_audio_ctx(Self::clamp_to_c_int(self.audio_ctx));
        param.set_speed_up(self.speed_up);
        // param.set_tdrz_enable(self.tinydiarize);

        // A zero increment means decoding never retries at a higher temperature.
        let temperature_inc = if self.no_fallback { 0.0 } else { self.temperature_inc };
        param.set_temperature_inc(temperature_inc);

        // Without this the caller's prompt tokens would be silently dropped.
        param.set_tokens(tokens);

        param
    }

    /// Converts an unsigned setting to `c_int`, saturating at `c_int::MAX`.
    fn clamp_to_c_int<T: std::convert::TryInto<c_int>>(value: T) -> c_int {
        std::convert::TryInto::try_into(value).unwrap_or(c_int::MAX)
    }
}

/// Network endpoint settings from the `server` section of `config.yaml`.
#[derive(Deserialize, Debug)]
pub(crate) struct Server {
    /// TCP/UDP port number.
    pub(crate) port: u16,
    /// IPv4 or IPv6 address.
    pub(crate) host: IpAddr,
}

/// Root of `config.yaml`, as deserialized by `load_config`.
#[derive(Deserialize, Debug)]
pub struct Config {
    /// The `whisper` section: model and decoding settings.
    pub(crate) whisper: WhisperConfig,
    /// The `server` section: host and port.
    pub(crate) server: Server,
}

#[cfg(test)]
mod tests {
    /// Smoke test: `config.yaml` in the working directory loads and deserializes
    /// through `load_config`, the same code path the application uses.
    #[test]
    fn load() {
        let config = super::load_config().expect("failed to load config file");
        println!("{:?}", config);
    }
}