Spaces:
Build error
Build error
| use crate::config::{WhisperConfig, CONFIG}; | |
| use crate::group::GroupedWithin; | |
| use lazy_static::lazy_static; | |
| use std::collections::VecDeque; | |
| use std::ffi::c_int; | |
| use std::fmt::{Debug, Display, Formatter}; | |
| use std::thread::sleep; | |
| use std::time::Duration; | |
| use tokio::sync::{broadcast, mpsc, oneshot}; | |
| use whisper_rs::{convert_integer_to_float_audio, WhisperContext, WhisperState, WhisperToken}; | |
| use whisper_rs_sys::WHISPER_SAMPLE_RATE; | |
| lazy_static! { | |
| static ref WHISPER_CONTEXT: WhisperContext = | |
| WhisperContext::new(&*CONFIG.whisper.model).expect("failed to create WhisperContext"); | |
| } | |
| pub(crate) enum Error { | |
| WhisperError { | |
| description: String, | |
| error: whisper_rs::WhisperError, | |
| }, | |
| } | |
| impl Error { | |
| fn whisper_error(description: &str, error: whisper_rs::WhisperError) -> Self { | |
| Self::WhisperError { | |
| description: description.to_string(), | |
| error, | |
| } | |
| } | |
| } | |
| impl Display for Error { | |
| fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { | |
| match self { | |
| Self::WhisperError { description, error } => { | |
| write!(f, "WhisperError: {}: {}", description, error) | |
| } | |
| } | |
| } | |
| } | |
| impl std::error::Error for Error { | |
| fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { | |
| match self { | |
| Self::WhisperError { error, .. } => Some(error), | |
| } | |
| } | |
| } | |
| fn pcm_i16_to_f32(input: &Vec<u8>) -> Vec<f32> { | |
| let pcm_i16 = input | |
| .chunks_exact(2) | |
| .map(|chunk| { | |
| let mut buf = [0u8; 2]; | |
| buf.copy_from_slice(chunk); | |
| i16::from_le_bytes(buf) | |
| }) | |
| .collect::<Vec<i16>>(); | |
| convert_integer_to_float_audio(pcm_i16.as_slice()) | |
| } | |
/// A transcribed span of audio.
///
/// `Clone` is required because segments are sent through a `tokio::sync::broadcast` channel
/// (whose `channel()` requires `T: Clone`) and copied out of slices.
#[derive(Debug, Clone)]
pub struct Segment {
    /// Start of the segment, in milliseconds.
    pub start_timestamp: i64,
    /// End of the segment, in milliseconds.
    pub end_timestamp: i64,
    /// Transcribed text.
    pub text: String,
    /// Whisper token ids of the segment, reused as prompt context for later inferences.
    tokens: Vec<c_int>,
}
| pub struct WhisperHandler { | |
| tx: mpsc::Sender<Vec<u8>>, | |
| transcription_tx: broadcast::Sender<Vec<Segment>>, | |
| stop_handle: Option<oneshot::Sender<()>>, | |
| } | |
| impl WhisperHandler { | |
| pub(crate) fn new(config: WhisperConfig, prompt: String) -> Result<Self, Error> { | |
| let (stop_handle, mut stop_signal) = oneshot::channel(); | |
| let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<u8>>(128); | |
| let (transcription_tx, _) = broadcast::channel::<Vec<Segment>>(128); | |
| let shared_transcription_tx = transcription_tx.clone(); | |
| let state = WHISPER_CONTEXT | |
| .create_state() | |
| .map_err(|e| Error::whisper_error("failed to create WhisperState", e))?; | |
| let preset_prompt_tokens = WHISPER_CONTEXT | |
| .tokenize(prompt.as_str(), CONFIG.whisper.max_prompt_tokens) | |
| .map_err(|e| Error::whisper_error("failed to tokenize prompt", e))?; | |
| tokio::task::spawn_blocking(move || { | |
| let mut detector = Detector::new(state, &CONFIG.whisper, preset_prompt_tokens); | |
| let mut grouped = GroupedWithin::new( | |
| detector.n_samples_step * 2, | |
| Duration::from_millis(config.step_ms as u64), | |
| pcm_rx, | |
| u16::MAX as usize, | |
| ); | |
| while let Err(oneshot::error::TryRecvError::Empty) = stop_signal.try_recv() { | |
| let new_pcm_f32 = match grouped.next() { | |
| Err(mpsc::error::TryRecvError::Disconnected) => break, | |
| Err(mpsc::error::TryRecvError::Empty) => { | |
| sleep(Duration::from_millis(10)); | |
| continue; | |
| } | |
| Ok(data) => pcm_i16_to_f32(&data), | |
| }; | |
| detector.feed(new_pcm_f32); | |
| let segments = match detector.inference() { | |
| Ok(result) => { | |
| if result.is_empty() { | |
| continue; | |
| } | |
| result | |
| } | |
| Err(err) => { | |
| tracing::warn!("failed to inference: {}", err); | |
| continue; | |
| } | |
| }; | |
| if tracing::enabled!(tracing::Level::TRACE) { | |
| for segment in segments.iter() { | |
| tracing::trace!( | |
| "[{}-{}]s SEGMENT: {}", | |
| segment.start_timestamp as f32 / 1000.0, | |
| segment.end_timestamp as f32 / 1000.0, | |
| segment.text | |
| ); | |
| } | |
| } | |
| if let Err(e) = shared_transcription_tx.send(segments) { | |
| tracing::error!("failed to send transcription: {}", e); | |
| break; | |
| }; | |
| } | |
| }); | |
| Ok(Self { | |
| tx: pcm_tx, | |
| transcription_tx, | |
| stop_handle: Some(stop_handle), | |
| }) | |
| } | |
| pub fn subscribe(&self) -> broadcast::Receiver<Vec<Segment>> { | |
| self.transcription_tx.subscribe() | |
| } | |
| pub async fn send(&self, data: Vec<u8>) -> Result<(), mpsc::error::SendError<Vec<u8>>> { | |
| self.tx.send(data).await | |
| } | |
| } | |
| struct Detector { | |
| state: WhisperState<'static>, | |
| config: &'static WhisperConfig, | |
| preset_prompt_tokens: Vec<WhisperToken>, | |
| n_samples_keep: usize, | |
| n_samples_step: usize, | |
| n_samples_len: usize, | |
| prompt_tokens: Vec<c_int>, | |
| pcm_f32: VecDeque<f32>, | |
| offset: usize, | |
| stable_offset: usize, | |
| } | |
| impl Detector { | |
| fn new( | |
| state: WhisperState<'static>, | |
| config: &'static WhisperConfig, | |
| preset_prompt_tokens: Vec<WhisperToken>, | |
| ) -> Self { | |
| Detector { | |
| state, | |
| config, | |
| preset_prompt_tokens, | |
| n_samples_keep: (config.keep_ms * WHISPER_SAMPLE_RATE / 1000) as usize, | |
| n_samples_step: (config.step_ms * WHISPER_SAMPLE_RATE / 1000) as usize, | |
| n_samples_len: (config.length_ms * WHISPER_SAMPLE_RATE / 1000) as usize, | |
| prompt_tokens: Default::default(), | |
| pcm_f32: VecDeque::from(vec![0f32; 30 * WHISPER_SAMPLE_RATE as usize]), | |
| offset: 0, | |
| stable_offset: 0, | |
| } | |
| } | |
| fn feed(&mut self, new_pcm_f32: Vec<f32>) { | |
| self.pcm_f32.extend(new_pcm_f32); | |
| if self.pcm_f32.len() < self.n_samples_len { | |
| return; | |
| } | |
| let len_to_drain = self | |
| .pcm_f32 | |
| .drain(0..(self.pcm_f32.len() - self.n_samples_len)) | |
| .len(); | |
| self.offset += len_to_drain; | |
| } | |
| fn inference(&mut self) -> Result<Vec<Segment>, Error> { | |
| let prompt_tokens = [ | |
| self.preset_prompt_tokens.as_slice(), | |
| self.prompt_tokens.as_slice(), | |
| ] | |
| .concat(); | |
| let params = self.config.params.to_full_params(prompt_tokens.as_slice()); | |
| let start = std::time::Instant::now(); | |
| let _ = self | |
| .state | |
| .full(params, self.pcm_f32.make_contiguous()) | |
| .map_err(|e| Error::whisper_error("failed to initialize WhisperState", e))?; | |
| let end = std::time::Instant::now(); | |
| if end - start > Duration::from_millis(self.config.step_ms as u64) { | |
| tracing::warn!( | |
| "full([{}]) took {} ms too slow", | |
| self.pcm_f32.len(), | |
| (end - start).as_millis() | |
| ); | |
| } | |
| let timestamp_offset: i64 = (self.offset * 1000 / WHISPER_SAMPLE_RATE as usize) as i64; | |
| let stable_offset: i64 = (self.stable_offset * 1000 / WHISPER_SAMPLE_RATE as usize) as i64; | |
| let num_segments = self | |
| .state | |
| .full_n_segments() | |
| .map_err(|e| Error::whisper_error("failed to get number of segments", e))?; | |
| let mut segments: Vec<Segment> = Vec::with_capacity(num_segments as usize); | |
| for i in 0..num_segments { | |
| let end_timestamp: i64 = timestamp_offset | |
| + 10 * self | |
| .state | |
| .full_get_segment_t1(i) | |
| .map_err(|e| Error::whisper_error("failed to get end timestamp", e))?; | |
| if end_timestamp <= stable_offset { | |
| continue; | |
| } | |
| let start_timestamp: i64 = timestamp_offset | |
| + 10 * self | |
| .state | |
| .full_get_segment_t0(i) | |
| .map_err(|e| Error::whisper_error("failed to get start timestamp", e))?; | |
| let segment = self | |
| .state | |
| .full_get_segment_text(i) | |
| .map_err(|e| Error::whisper_error("failed to get segment", e))?; | |
| let num_tokens = self | |
| .state | |
| .full_n_tokens(i) | |
| .map_err(|e| Error::whisper_error("failed to get segment tokens", e))?; | |
| let mut segment_tokens = Vec::with_capacity(num_tokens as usize); | |
| for j in 0..num_tokens { | |
| segment_tokens.push( | |
| self.state | |
| .full_get_token_id(i, j) | |
| .map_err(|e| Error::whisper_error("failed to get token", e))?, | |
| ); | |
| } | |
| segments.push(Segment { | |
| start_timestamp, | |
| end_timestamp, | |
| text: segment.trim().to_string(), | |
| tokens: segment_tokens, | |
| }); | |
| } | |
| let Some((_last, init)) = segments.split_last() else { | |
| return Ok(Vec::default()); | |
| }; | |
| let Some((last_2_seg, _)) = init.split_last() else { | |
| return Ok(Vec::default()); | |
| }; | |
| let offset = (last_2_seg.end_timestamp - timestamp_offset) as usize / 1000 | |
| * WHISPER_SAMPLE_RATE as usize; | |
| self.stable_offset = offset; | |
| self.drop_stable_by_segments(init); | |
| Ok(init.into()) | |
| } | |
| fn drop_stable_by_segments(&mut self, stable_segments: &[Segment]) { | |
| let Some(last) = stable_segments.last() else { | |
| return; | |
| }; | |
| let drop_offset: usize = (last.end_timestamp as usize / 1000 * WHISPER_SAMPLE_RATE as usize | |
| - self.offset) as usize; | |
| let len_to_drain = self.pcm_f32.drain(0..drop_offset).len(); | |
| self.offset += len_to_drain; | |
| for segment in stable_segments.into_iter() { | |
| self.prompt_tokens.extend(&segment.tokens); | |
| } | |
| if self.prompt_tokens.len() > self.config.max_prompt_tokens { | |
| let _ = self | |
| .prompt_tokens | |
| .drain(0..(self.prompt_tokens.len() - self.config.max_prompt_tokens)) | |
| .len(); | |
| } | |
| } | |
| } | |
| impl Drop for WhisperHandler { | |
| fn drop(&mut self) { | |
| let Some(stop_handle) = self.stop_handle.take() else { | |
| return tracing::warn!("WhisperHandler::drop() called without stop_handle"); | |
| }; | |
| if let Err(_) = stop_handle.send(()) { | |
| tracing::warn!("WhisperHandler::drop() failed to send stop signal"); | |
| } | |
| } | |
| } | |