Spaces:
Sleeping
Sleeping
File size: 3,086 Bytes
ee7230e a4dee07 ee7230e a4dee07 819a8c6 5cb1ef7 a4dee07 819a8c6 a4dee07 ee7230e 819a8c6 ee7230e 819a8c6 a4dee07 ee7230e a4dee07 ee7230e 819a8c6 ee7230e 819a8c6 ee7230e 819a8c6 a4dee07 ee7230e a4dee07 ee7230e 819a8c6 ee7230e a4dee07 819a8c6 ee7230e 5b9ecd0 ee7230e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
use std::ffi::c_int;
use std::net::IpAddr;
use lazy_static::lazy_static;
use serde::{Deserialize};
use whisper_rs::FullParams;
#[derive(Debug)]
pub enum Error {
IoError(std::io::Error),
ConfigError(serde_yaml::Error),
}
pub(crate) fn load_config() -> Result<Config, Error> {
let config_str = std::fs::read_to_string("config.yaml").map_err(|e| Error::IoError(e))?;
let config: Config = serde_yaml::from_str(config_str.as_str())
.map_err(|e| Error::ConfigError(e))?;
return Ok(config)
}
lazy_static! {
    /// Process-wide configuration, produced by `load_config()` the first time it is accessed.
    ///
    /// # Panics
    ///
    /// The first access panics with "failed to load config" if `load_config()` returns an error.
    pub static ref CONFIG: Config = {
        let loaded = load_config();
        loaded.expect("failed to load config")
    };
}
/// Whisper-related settings, deserialized from the `whisper` key of the configuration.
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct WhisperConfig {
    /// Decoding parameters; see `WhisperParams::to_full_params`.
    pub(crate) params: WhisperParams,
    // NOTE(review): the three `_ms` fields are presumably durations in milliseconds
    // that control the audio windowing — confirm against the code that consumes them.
    pub(crate) step_ms: u32,
    pub(crate) length_ms: u32,
    pub(crate) keep_ms: u32,
    // NOTE(review): presumably a path or name identifying the model to load — confirm.
    pub(crate) model: String,
    // NOTE(review): presumably an upper bound on prompt tokens kept between runs — confirm.
    pub(crate) max_prompt_tokens: usize,
}
/// Decoding parameters that `to_full_params` copies into a whisper `FullParams`.
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct WhisperParams {
    /// Thread count for `set_n_threads`; when `None`, the available parallelism is used.
    pub(crate) n_threads: Option<usize>,
    /// Value passed to `set_max_tokens`.
    pub(crate) max_tokens: u32,
    /// Value passed to `set_audio_ctx`.
    pub(crate) audio_ctx: u32,
    /// Value passed to `set_speed_up`.
    pub(crate) speed_up: bool,
    /// Value passed to `set_translate`.
    pub(crate) translate: bool,
    /// Deserialized, but not read by `to_full_params`.
    pub(crate) no_fallback: bool,
    /// Value passed to `set_print_special`.
    pub(crate) print_special: bool,
    /// Value passed to `set_print_realtime`.
    pub(crate) print_realtime: bool,
    /// Value passed to `set_print_progress`.
    pub(crate) print_progress: bool,
    /// Negated and passed to `set_print_timestamps`.
    pub(crate) no_timestamps: bool,
    /// Value passed to `set_temperature_inc`.
    pub(crate) temperature_inc: f32,
    /// Deserialized, but not read by `to_full_params`, which always passes `false`
    /// to `set_single_segment`.
    pub(crate) single_segment: bool,
    // Disabled together with the `set_tdrz_enable` call in `to_full_params`:
    // pub(crate) tinydiarize: bool,
    /// Optional language passed to `set_language`.
    pub(crate) language: Option<String>,
}
/// An empty `c_int` array.
// NOTE(review): nothing in the visible part of this file references this constant.
const NONE: [c_int; 0] = [];
impl WhisperParams {
    /// Builds a whisper [`FullParams`] from these settings.
    ///
    /// The result starts from `FullParams::new(Default::default())` and then has the
    /// print flags, translation flag, token limit, language, thread count, audio
    /// context, speed-up flag and temperature increment applied from `self`.
    ///
    /// The thread count is `n_threads` when configured; otherwise it is
    /// [`std::thread::available_parallelism`], falling back to 4 if that cannot be
    /// determined.
    ///
    /// Unsigned settings are converted to `c_int` with saturation, so an out-of-range
    /// value becomes `c_int::MAX` instead of silently wrapping to a negative number.
    ///
    /// NOTE(review): `tokens` is accepted (and ties the `'b` lifetime of the result)
    /// but is never applied to the parameters here; `single_segment` is always set to
    /// `false` regardless of the configured value; and `no_fallback` is not consulted.
    /// Confirm with the callers whether these are intentional.
    pub(crate) fn to_full_params<'a, 'b>(&'a self, tokens: &'b [c_int]) -> FullParams<'a, 'b> {
        let mut param = FullParams::new(Default::default());

        param.set_print_progress(self.print_progress);
        param.set_print_special(self.print_special);
        param.set_print_realtime(self.print_realtime);
        param.set_print_timestamps(!self.no_timestamps);
        param.set_translate(self.translate);
        param.set_single_segment(false);
        param.set_max_tokens(c_int::try_from(self.max_tokens).unwrap_or(c_int::MAX));
        param.set_language(self.language.as_deref());

        // Only query the system when no explicit thread count is configured.
        let n_threads = self
            .n_threads
            .unwrap_or_else(|| std::thread::available_parallelism().map_or(4, |n| n.get()));
        param.set_n_threads(c_int::try_from(n_threads).unwrap_or(c_int::MAX));

        param.set_audio_ctx(c_int::try_from(self.audio_ctx).unwrap_or(c_int::MAX));
        param.set_speed_up(self.speed_up);
        // param.set_tdrz_enable(self.tinydiarize);
        param.set_temperature_inc(self.temperature_inc);
        param
    }
}
/// Server settings, deserialized from the `server` key of the configuration.
// NOTE(review): presumably the address and port the service listens on — confirm
// against the code that consumes these fields.
#[derive(Deserialize, Debug)]
pub(crate) struct Server {
    pub(crate) port: u16,
    pub(crate) host: IpAddr,
}
/// Top-level application configuration, as deserialized from `config.yaml`.
#[derive(Deserialize, Debug)]
pub struct Config {
    /// Contents of the `whisper` key.
    pub(crate) whisper: WhisperConfig,
    /// Contents of the `server` key.
    pub(crate) server: Server,
}
#[cfg(test)]
mod tests {
    use super::Config;

    /// Checks that `config.yaml` (relative to the working directory the tests run in)
    /// can be read and deserialized into [`Config`].
    #[tokio::test]
    async fn load() {
        let config_str = tokio::fs::read_to_string("config.yaml")
            .await
            .expect("failed to read config file");
        let params: Config =
            serde_yaml::from_str(&config_str).expect("failed to parse config file");
        println!("{:?}", params);
    }
}