use std::{ collections::VecDeque, fmt::{Debug, Display, Formatter}, thread::sleep, time::Duration, sync::Arc }; use std::ops::Deref; use fvad::SampleRate; use tokio::sync::{broadcast, mpsc, oneshot}; use tokio::time::Instant; use tracing::{warn}; use whisper_rs::{convert_integer_to_float_audio, WhisperContext, WhisperError, WhisperState, WhisperToken, WhisperTokenData}; use crate::{config::WhisperConfig, group::GroupedWithin}; const WHISPER_SAMPLE_RATE: usize = whisper_rs_sys::WHISPER_SAMPLE_RATE as usize; #[derive(Clone)] pub struct Context { inner: Arc, } impl Context { pub fn new(model: &str) -> Result { WhisperContext::new(model) .map(|context| Self { inner: Arc::new(context) }) } pub fn create_handler(&self, config: WhisperConfig, prompt: String) -> WhisperHandler { WhisperHandler::new(self.clone(), config, prompt) } } impl Deref for Context { type Target = WhisperContext; fn deref(&self) -> &Self::Target { &self.inner } } #[derive(Debug)] pub enum Error { WhisperError { description: String, error: WhisperError, }, } impl Error { fn whisper_error(description: &str, error: whisper_rs::WhisperError) -> Self { Self::WhisperError { description: description.to_string(), error, } } } impl Display for Error { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { Self::WhisperError { description, error } => { write!(f, "WhisperError: {}: {}", description, error) } } } } impl std::error::Error for Error { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { Self::WhisperError { error, .. } => Some(error), } } } #[derive(Clone, Debug)] pub enum Output { Unstable(Segment), Stable(Segment), } #[derive(Clone, Debug)] pub struct Segment { pub start_timestamp: i64, pub end_timestamp: i64, pub text: String, tokens: Vec, } pub struct WhisperHandler { tx: mpsc::Sender>, transcription_tx: broadcast::Sender>, stop_handle: Option>, } impl WhisperHandler { fn new(whisper_context: Context, config: WhisperConfig, prompt: String) -> Self { let vad_slice_size = WHISPER_SAMPLE_RATE / 100 * 3; let (stop_handle, mut stop_signal) = oneshot::channel(); let (pcm_tx, pcm_rx) = mpsc::channel::>(128); let (transcription_tx, _) = broadcast::channel::>(128); let shared_transcription_tx = transcription_tx.clone(); let task = move || { let state = match whisper_context.create_state() { Ok(state) => state, Err(e) => { tracing::error!("failed to create WhisperState: {}", e); return Err(Error::whisper_error("failed to create WhisperState", e)) } }; let preset_prompt_tokens = match whisper_context.tokenize(&prompt, config.max_prompt_tokens) { Ok(tokens) => tokens, Err(e) => { tracing::error!("failed to tokenize prompt: {}", e); return Err(Error::whisper_error("failed to tokenize prompt", e)) } }; let mut vad = fvad::Fvad::new().expect("failed to create VAD") .set_sample_rate(SampleRate::Rate16kHz); let mut detector = Detector::new(state, &config, preset_prompt_tokens); let mut grouped = GroupedWithin::new( detector.n_samples_step, Duration::from_millis(config.step_ms as u64), pcm_rx, u16::MAX as usize, ); while let Err(oneshot::error::TryRecvError::Empty) = stop_signal.try_recv() { if detector.has_crossed_next_line() { if let Some(segment) = detector.next_line() { let segments = vec![Output::Stable(segment)]; if let Err(e) = shared_transcription_tx.send(segments) { tracing::error!("failed to send transcription: {}", e); break; }; } } let new_pcm_f32 = match grouped.next() { Err(mpsc::error::TryRecvError::Disconnected) => break, Err(mpsc::error::TryRecvError::Empty) => { sleep(Duration::from_millis(10)); continue; } Ok(data) => { let active_voice = data .chunks(vad_slice_size) .filter(|frame| { if frame.len() != vad_slice_size { true } else { vad.is_voice_frame(frame).unwrap_or(true) } }) .collect::>() .concat(); convert_integer_to_float_audio(&active_voice) }, }; detector.feed(new_pcm_f32); let segments = match detector.inference() { Ok(result) => { if result.is_empty() { continue; } result } Err(err) => { warn!("failed to inference: {}", err); continue; } }; let outputs = segments .iter() .map(|segment| Output::Unstable(segment.clone())) .collect::>(); if let Err(e) = shared_transcription_tx.send(outputs) { tracing::error!("failed to send transcription: {}", e); break; }; } Ok::<(), Error>(()) }; tokio::task::spawn_blocking(task); Self { tx: pcm_tx, transcription_tx, stop_handle: Some(stop_handle), } } pub fn subscribe(&self) -> broadcast::Receiver> { self.transcription_tx.subscribe() } pub async fn send_i16(&mut self, data: Vec) -> Result<(), mpsc::error::SendError>> { self.tx.send(data).await } } #[allow(dead_code)] struct Detector<'a> { state: WhisperState<'a>, config: &'a WhisperConfig, start_time: Instant, segment: Option, line_num: usize, preset_prompt_tokens: Vec, n_samples_keep: usize, n_samples_step: usize, n_samples_len: usize, prompt_tokens: Vec, pcm_f32: VecDeque, offset: usize, } impl <'a> Detector<'a> { fn new( state: WhisperState<'a>, config: &'a WhisperConfig, preset_prompt_tokens: Vec, ) -> Self { Detector { state, config, start_time: Instant::now(), segment: None, line_num: 0, preset_prompt_tokens, n_samples_keep: config.keep_ms * WHISPER_SAMPLE_RATE / 1000, n_samples_step: config.step_ms * WHISPER_SAMPLE_RATE / 1000, n_samples_len: config.length_ms * WHISPER_SAMPLE_RATE / 1000, prompt_tokens: Default::default(), pcm_f32: VecDeque::with_capacity(config.length_ms * WHISPER_SAMPLE_RATE / 1000), offset: 0, } } fn feed(&mut self, new_pcm_f32: Vec) { self.pcm_f32.extend(new_pcm_f32); if self.pcm_f32.len() < self.n_samples_len { return; } } fn inference(&mut self) -> Result, Error> { let params = self.config.params.to_full_params(self.prompt_tokens.as_slice()); let start = std::time::Instant::now(); let _ = self .state .full(params, self.pcm_f32.make_contiguous()) .map_err(|e| Error::whisper_error("failed to initialize WhisperState", e))?; let end = std::time::Instant::now(); if end - start > Duration::from_millis(self.config.step_ms as u64) { // warn!( // "full([{}]) took {} ms too slow", // self.pcm_f32.len(), // (end - start).as_millis() // ); } let timestamp_offset: i64 = (self.offset * 1000 / WHISPER_SAMPLE_RATE) as i64; let num_segments = self .state .full_n_segments() .map_err(|e| Error::whisper_error("failed to get number of segments", e))?; let mut segments: Vec = Vec::with_capacity(num_segments as usize); for i in 0..num_segments { let start_timestamp: i64 = timestamp_offset + 10 * self .state .full_get_segment_t0(i) .map_err(|e| Error::whisper_error("failed to get start timestamp", e))?; let end_timestamp: i64 = timestamp_offset + 10 * self .state .full_get_segment_t1(i) .map_err(|e| Error::whisper_error("failed to get end timestamp", e))?; let segment = self .state .full_get_segment_text(i) .map_err(|e| Error::whisper_error("failed to get segment", e))?; let num_tokens = self .state .full_n_tokens(i) .map_err(|e| Error::whisper_error("failed to get segment tokens", e))?; let mut segment_tokens = Vec::with_capacity(num_tokens as usize); for j in 0..num_tokens { let token_data = self.state.full_get_token_data(i, j) .map_err(|e| Error::whisper_error("failed to get token data", e))?; segment_tokens.push(token_data); } segments.push(Segment { start_timestamp, end_timestamp, text: segment.trim().to_string(), tokens: segment_tokens, }); } self.segment = segments.first().cloned(); Ok(segments.to_vec()) } fn remember_prompt(&mut self) { let Some(segment) = self.segment.as_ref() else { return }; let tokens = segment .tokens .iter() .map(|td| td.tid) .collect::>(); self.prompt_tokens.extend(tokens); if self.prompt_tokens.len() > self.config.max_prompt_tokens { let _ = self.prompt_tokens.drain(0..(self.prompt_tokens.len() - self.config.max_prompt_tokens)).len(); } } fn has_crossed_next_line(&self) -> bool { let now = Instant::now(); let elapsed = now - self.start_time; let line_number: usize = (elapsed.as_millis() / self.config.length_ms as u128) as usize; line_number > self.line_num } fn next_line(&mut self) -> Option { if self.pcm_f32.len() > self.n_samples_keep { let drain_size = self.pcm_f32.drain(0..(self.pcm_f32.len() - self.n_samples_keep)).len(); self.offset += drain_size; } else { let size_will_clear = self.pcm_f32.len(); self.pcm_f32.clear(); self.offset += size_will_clear; } self.line_num += 1; self.remember_prompt(); self.segment.take() } } impl Drop for WhisperHandler { fn drop(&mut self) { let Some(stop_handle) = self.stop_handle.take() else { return warn!("WhisperHandler::drop() called without stop_handle"); }; if stop_handle.send(()).is_err() { warn!("WhisperHandler::drop() failed to send stop signal"); } } } #[cfg(test)] mod test { use super::*; use std::io::{stdout, Write}; use hound; use tracing_test; use tracing::{info, debug}; use crate::config::WhisperParams; use lazy_static::lazy_static; async fn print_output(output: Output) { match output { Output::Stable(stable) => { print!("\x1b[2K\r"); print!("{}\n", stable.text); }, Output::Unstable(unstable) => { // back to previous line of console print!("\x1b[2K\r"); print!("{}", " ".repeat(100)); print!("\x1b[2K\r"); print!("{} ...", unstable.text); } } stdout().flush().unwrap(); } lazy_static! { static ref CONFIG: WhisperConfig = WhisperConfig { length_ms: 5000, step_ms: 500, keep_ms: 200, model: "models/ggml-large-v3.bin".to_string(), max_prompt_tokens: 32, context_confidence_threshold: 0.5, params: WhisperParams { n_threads: None, max_tokens: None, audio_ctx: None, speed_up: None, translate: None, no_context: None, print_special: None, print_realtime: None, print_progress: None, token_timestamps: None, no_timestamps: None, temperature_inc: None, entropy_threshold: None, single_segment: Some(true), suppress_non_speech_tokens: None, n_max_text_ctx: None, language: Some("en".to_string()), } }; static ref CONTEXT: Context = Context::new(&CONFIG.model).expect("failed to create WhisperContext"); } #[tokio::test] #[tracing_test::traced_test] async fn test_whisper_handler() { let mut whisper_handler = CONTEXT .create_handler(&CONFIG, "Harry Potter and the Philosopher's Stone".to_string()) .expect("failed to create WhisperHandler"); let wav = hound::WavReader::open("samples/ADHD_1A.wav") .expect("failed to open wav"); let spec = wav.spec(); info!("{:?}", spec); let samples = wav .into_samples::() .map(|s| s.unwrap()) .collect::>(); let chunks = samples.chunks(1600) .map(|chunk| chunk.to_vec()) .into_iter(); let mut rx = whisper_handler.subscribe(); let send_fut = async { for chunk in chunks { let _ = whisper_handler.send_i16(chunk).await.expect("failed to send sample"); tokio::time::sleep(Duration::from_millis(100)).await; } }; let recv_fut = async { while let Ok(outputs) = rx.recv().await { let Some(output) = outputs.first() else { continue }; match output { Output::Stable(stable) => { debug!("{}", stable.text); }, Output::Unstable(_unstable) => { } } } }; tokio::join!(send_fut, recv_fut); } }