Spaces:
Sleeping
Sleeping
use std::{ | |
collections::VecDeque, | |
fmt::{Debug, Display, Formatter}, | |
thread::sleep, | |
time::Duration, | |
sync::Arc | |
}; | |
use std::ops::Deref; | |
use fvad::SampleRate; | |
use tokio::sync::{broadcast, mpsc, oneshot}; | |
use tokio::time::Instant; | |
use tracing::{warn}; | |
use whisper_rs::{convert_integer_to_float_audio, WhisperContext, WhisperError, WhisperState, WhisperToken, WhisperTokenData}; | |
use crate::{config::WhisperConfig, group::GroupedWithin}; | |
const WHISPER_SAMPLE_RATE: usize = whisper_rs_sys::WHISPER_SAMPLE_RATE as usize; | |
pub struct Context { | |
inner: Arc<WhisperContext>, | |
} | |
impl Context { | |
pub fn new(model: &str) -> Result<Context, WhisperError> { | |
WhisperContext::new(model) | |
.map(|context| Self { inner: Arc::new(context) }) | |
} | |
pub fn create_handler(&self, config: WhisperConfig, prompt: String) -> WhisperHandler { | |
WhisperHandler::new(self.clone(), config, prompt) | |
} | |
} | |
impl Deref for Context { | |
type Target = WhisperContext; | |
fn deref(&self) -> &Self::Target { | |
&self.inner | |
} | |
} | |
pub enum Error { | |
WhisperError { | |
description: String, | |
error: WhisperError, | |
}, | |
} | |
impl Error { | |
fn whisper_error(description: &str, error: whisper_rs::WhisperError) -> Self { | |
Self::WhisperError { | |
description: description.to_string(), | |
error, | |
} | |
} | |
} | |
impl Display for Error { | |
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { | |
match self { | |
Self::WhisperError { description, error } => { | |
write!(f, "WhisperError: {}: {}", description, error) | |
} | |
} | |
} | |
} | |
impl std::error::Error for Error { | |
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { | |
match self { | |
Self::WhisperError { error, .. } => Some(error), | |
} | |
} | |
} | |
pub enum Output { | |
Unstable(Segment), | |
Stable(Segment), | |
} | |
pub struct Segment { | |
pub start_timestamp: i64, | |
pub end_timestamp: i64, | |
pub text: String, | |
tokens: Vec<WhisperTokenData>, | |
} | |
pub struct WhisperHandler { | |
tx: mpsc::Sender<Vec<i16>>, | |
transcription_tx: broadcast::Sender<Vec<Output>>, | |
stop_handle: Option<oneshot::Sender<()>>, | |
} | |
impl WhisperHandler { | |
fn new(whisper_context: Context, config: WhisperConfig, prompt: String) -> Self { | |
let vad_slice_size = WHISPER_SAMPLE_RATE / 100 * 3; | |
let (stop_handle, mut stop_signal) = oneshot::channel(); | |
let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<i16>>(128); | |
let (transcription_tx, _) = broadcast::channel::<Vec<Output>>(128); | |
let shared_transcription_tx = transcription_tx.clone(); | |
let task = move || { | |
let state = match whisper_context.create_state() { | |
Ok(state) => state, | |
Err(e) => { | |
tracing::error!("failed to create WhisperState: {}", e); | |
return Err(Error::whisper_error("failed to create WhisperState", e)) | |
} | |
}; | |
let preset_prompt_tokens = match whisper_context.tokenize(&prompt, config.max_prompt_tokens) { | |
Ok(tokens) => tokens, | |
Err(e) => { | |
tracing::error!("failed to tokenize prompt: {}", e); | |
return Err(Error::whisper_error("failed to tokenize prompt", e)) | |
} | |
}; | |
let mut vad = fvad::Fvad::new().expect("failed to create VAD") | |
.set_sample_rate(SampleRate::Rate16kHz); | |
let mut detector = Detector::new(state, &config, preset_prompt_tokens); | |
let mut grouped = GroupedWithin::new( | |
detector.n_samples_step, | |
Duration::from_millis(config.step_ms as u64), | |
pcm_rx, | |
u16::MAX as usize, | |
); | |
while let Err(oneshot::error::TryRecvError::Empty) = stop_signal.try_recv() { | |
if detector.has_crossed_next_line() { | |
if let Some(segment) = detector.next_line() { | |
let segments = vec![Output::Stable(segment)]; | |
if let Err(e) = shared_transcription_tx.send(segments) { | |
tracing::error!("failed to send transcription: {}", e); | |
break; | |
}; | |
} | |
} | |
let new_pcm_f32 = match grouped.next() { | |
Err(mpsc::error::TryRecvError::Disconnected) => break, | |
Err(mpsc::error::TryRecvError::Empty) => { | |
sleep(Duration::from_millis(10)); | |
continue; | |
} | |
Ok(data) => { | |
let active_voice = data | |
.chunks(vad_slice_size) | |
.filter(|frame| { | |
if frame.len() != vad_slice_size { | |
true | |
} else { | |
vad.is_voice_frame(frame).unwrap_or(true) | |
} | |
}) | |
.collect::<Vec<_>>() | |
.concat(); | |
convert_integer_to_float_audio(&active_voice) | |
}, | |
}; | |
detector.feed(new_pcm_f32); | |
let segments = match detector.inference() { | |
Ok(result) => { | |
if result.is_empty() { | |
continue; | |
} | |
result | |
} | |
Err(err) => { | |
warn!("failed to inference: {}", err); | |
continue; | |
} | |
}; | |
let outputs = segments | |
.iter() | |
.map(|segment| Output::Unstable(segment.clone())) | |
.collect::<Vec<_>>(); | |
if let Err(e) = shared_transcription_tx.send(outputs) { | |
tracing::error!("failed to send transcription: {}", e); | |
break; | |
}; | |
} | |
Ok::<(), Error>(()) | |
}; | |
tokio::task::spawn_blocking(task); | |
Self { | |
tx: pcm_tx, | |
transcription_tx, | |
stop_handle: Some(stop_handle), | |
} | |
} | |
pub fn subscribe(&self) -> broadcast::Receiver<Vec<Output>> { | |
self.transcription_tx.subscribe() | |
} | |
pub async fn send_i16(&mut self, data: Vec<i16>) -> Result<(), mpsc::error::SendError<Vec<i16>>> { | |
self.tx.send(data).await | |
} | |
} | |
struct Detector<'a> { | |
state: WhisperState<'a>, | |
config: &'a WhisperConfig, | |
start_time: Instant, | |
segment: Option<Segment>, | |
line_num: usize, | |
preset_prompt_tokens: Vec<WhisperToken>, | |
n_samples_keep: usize, | |
n_samples_step: usize, | |
n_samples_len: usize, | |
prompt_tokens: Vec<WhisperToken>, | |
pcm_f32: VecDeque<f32>, | |
offset: usize, | |
} | |
impl <'a> Detector<'a> { | |
fn new( | |
state: WhisperState<'a>, | |
config: &'a WhisperConfig, | |
preset_prompt_tokens: Vec<WhisperToken>, | |
) -> Self { | |
Detector { | |
state, | |
config, | |
start_time: Instant::now(), | |
segment: None, | |
line_num: 0, | |
preset_prompt_tokens, | |
n_samples_keep: config.keep_ms * WHISPER_SAMPLE_RATE / 1000, | |
n_samples_step: config.step_ms * WHISPER_SAMPLE_RATE / 1000, | |
n_samples_len: config.length_ms * WHISPER_SAMPLE_RATE / 1000, | |
prompt_tokens: Default::default(), | |
pcm_f32: VecDeque::with_capacity(config.length_ms * WHISPER_SAMPLE_RATE / 1000), | |
offset: 0, | |
} | |
} | |
fn feed(&mut self, new_pcm_f32: Vec<f32>) { | |
self.pcm_f32.extend(new_pcm_f32); | |
if self.pcm_f32.len() < self.n_samples_len { | |
return; | |
} | |
} | |
fn inference(&mut self) -> Result<Vec<Segment>, Error> { | |
let params = self.config.params.to_full_params(self.prompt_tokens.as_slice()); | |
let start = std::time::Instant::now(); | |
let _ = self | |
.state | |
.full(params, self.pcm_f32.make_contiguous()) | |
.map_err(|e| Error::whisper_error("failed to initialize WhisperState", e))?; | |
let end = std::time::Instant::now(); | |
if end - start > Duration::from_millis(self.config.step_ms as u64) { | |
// warn!( | |
// "full([{}]) took {} ms too slow", | |
// self.pcm_f32.len(), | |
// (end - start).as_millis() | |
// ); | |
} | |
let timestamp_offset: i64 = (self.offset * 1000 / WHISPER_SAMPLE_RATE) as i64; | |
let num_segments = self | |
.state | |
.full_n_segments() | |
.map_err(|e| Error::whisper_error("failed to get number of segments", e))?; | |
let mut segments: Vec<Segment> = Vec::with_capacity(num_segments as usize); | |
for i in 0..num_segments { | |
let start_timestamp: i64 = timestamp_offset | |
+ 10 * self | |
.state | |
.full_get_segment_t0(i) | |
.map_err(|e| Error::whisper_error("failed to get start timestamp", e))?; | |
let end_timestamp: i64 = timestamp_offset | |
+ 10 * self | |
.state | |
.full_get_segment_t1(i) | |
.map_err(|e| Error::whisper_error("failed to get end timestamp", e))?; | |
let segment = self | |
.state | |
.full_get_segment_text(i) | |
.map_err(|e| Error::whisper_error("failed to get segment", e))?; | |
let num_tokens = self | |
.state | |
.full_n_tokens(i) | |
.map_err(|e| Error::whisper_error("failed to get segment tokens", e))?; | |
let mut segment_tokens = Vec::with_capacity(num_tokens as usize); | |
for j in 0..num_tokens { | |
let token_data = self.state.full_get_token_data(i, j) | |
.map_err(|e| Error::whisper_error("failed to get token data", e))?; | |
segment_tokens.push(token_data); | |
} | |
segments.push(Segment { | |
start_timestamp, | |
end_timestamp, | |
text: segment.trim().to_string(), | |
tokens: segment_tokens, | |
}); | |
} | |
self.segment = segments.first().cloned(); | |
Ok(segments.to_vec()) | |
} | |
fn remember_prompt(&mut self) { | |
let Some(segment) = self.segment.as_ref() else { | |
return | |
}; | |
let tokens = segment | |
.tokens | |
.iter() | |
.map(|td| td.tid) | |
.collect::<Vec<WhisperToken>>(); | |
self.prompt_tokens.extend(tokens); | |
if self.prompt_tokens.len() > self.config.max_prompt_tokens { | |
let _ = self.prompt_tokens.drain(0..(self.prompt_tokens.len() - self.config.max_prompt_tokens)).len(); | |
} | |
} | |
fn has_crossed_next_line(&self) -> bool { | |
let now = Instant::now(); | |
let elapsed = now - self.start_time; | |
let line_number: usize = (elapsed.as_millis() / self.config.length_ms as u128) as usize; | |
line_number > self.line_num | |
} | |
fn next_line(&mut self) -> Option<Segment> { | |
if self.pcm_f32.len() > self.n_samples_keep { | |
let drain_size = self.pcm_f32.drain(0..(self.pcm_f32.len() - self.n_samples_keep)).len(); | |
self.offset += drain_size; | |
} else { | |
let size_will_clear = self.pcm_f32.len(); | |
self.pcm_f32.clear(); | |
self.offset += size_will_clear; | |
} | |
self.line_num += 1; | |
self.remember_prompt(); | |
self.segment.take() | |
} | |
} | |
impl Drop for WhisperHandler { | |
fn drop(&mut self) { | |
let Some(stop_handle) = self.stop_handle.take() else { | |
return warn!("WhisperHandler::drop() called without stop_handle"); | |
}; | |
if stop_handle.send(()).is_err() { | |
warn!("WhisperHandler::drop() failed to send stop signal"); | |
} | |
} | |
} | |
#[cfg(test)]
mod test {
    use super::*;
    use crate::config::WhisperParams;
    use lazy_static::lazy_static;
    use std::io::{stdout, Write};
    use tracing::{debug, info};
    use tracing_test::traced_test;

    const MODEL_PATH: &str = "models/ggml-large-v3.bin";

    /// Renders one output on the terminal: stable lines are printed in full, and unstable
    /// text overwrites the current line.
    #[allow(dead_code)] // manual-debugging helper; the automated test logs via `debug!` instead
    fn print_output(output: Output) {
        match output {
            Output::Stable(stable) => {
                print!("\x1b[2K\r");
                println!("{}", stable.text);
            }
            Output::Unstable(unstable) => {
                // Clear the current console line before redrawing the provisional text.
                print!("\x1b[2K\r");
                print!("{}", " ".repeat(100));
                print!("\x1b[2K\r");
                print!("{} ...", unstable.text);
            }
        }
        stdout().flush().unwrap();
    }

    /// Builds a fresh config each time, because `create_handler` takes `WhisperConfig` by value.
    fn test_config() -> WhisperConfig {
        WhisperConfig {
            length_ms: 5000,
            step_ms: 500,
            keep_ms: 200,
            model: MODEL_PATH.to_string(),
            max_prompt_tokens: 32,
            context_confidence_threshold: 0.5,
            params: WhisperParams {
                n_threads: None,
                max_tokens: None,
                audio_ctx: None,
                speed_up: None,
                translate: None,
                no_context: None,
                print_special: None,
                print_realtime: None,
                print_progress: None,
                token_timestamps: None,
                no_timestamps: None,
                temperature_inc: None,
                entropy_threshold: None,
                single_segment: Some(true),
                suppress_non_speech_tokens: None,
                n_max_text_ctx: None,
                language: Some("en".to_string()),
            },
        }
    }

    lazy_static! {
        static ref CONTEXT: Context =
            Context::new(MODEL_PATH).expect("failed to create WhisperContext");
    }

    #[tokio::test]
    #[traced_test]
    async fn test_whisper_handler() {
        let mut whisper_handler = CONTEXT.create_handler(
            test_config(),
            "Harry Potter and the Philosopher's Stone".to_string(),
        );
        let wav = hound::WavReader::open("samples/ADHD_1A.wav").expect("failed to open wav");
        let spec = wav.spec();
        info!("{:?}", spec);
        let samples = wav
            .into_samples::<i16>()
            .map(|s| s.expect("failed to decode sample"))
            .collect::<Vec<i16>>();
        let mut rx = whisper_handler.subscribe();
        let send_fut = async move {
            // 1600 samples = 100 ms if the sample file is 16 kHz mono (assumption; see the
            // logged `spec`), so this paces the audio in roughly real time.
            for chunk in samples.chunks(1600) {
                whisper_handler
                    .send_i16(chunk.to_vec())
                    .await
                    .expect("failed to send sample");
                tokio::time::sleep(Duration::from_millis(100)).await;
            }
            // Dropping the handler stops the worker. Once every broadcast sender is gone,
            // `rx.recv()` returns `Closed` and the receive loop ends.
            drop(whisper_handler);
        };
        let recv_fut = async move {
            loop {
                match rx.recv().await {
                    Ok(outputs) => {
                        if let Some(Output::Stable(stable)) = outputs.first() {
                            debug!("{}", stable.text);
                        }
                    }
                    // A slow receiver only misses intermediate updates; keep listening.
                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
                    Err(broadcast::error::RecvError::Closed) => break,
                }
            }
        };
        tokio::join!(send_fut, recv_fut);
    }
}