Spaces:
Sleeping
Sleeping
WHISPER!
Browse files- Cargo.lock +5 -2
- Cargo.toml +4 -1
- config.yaml +3 -4
- src/config.rs +34 -16
- src/group.rs +59 -0
- src/main.rs +15 -22
- src/whisper.rs +199 -15
Cargo.lock
CHANGED
@@ -1419,14 +1419,17 @@ dependencies = [
|
|
1419 |
"aws-sdk-translate",
|
1420 |
"clap",
|
1421 |
"futures-util",
|
|
|
1422 |
"poem",
|
1423 |
"serde",
|
1424 |
"serde_json",
|
1425 |
"serde_yaml",
|
1426 |
"tokio",
|
1427 |
"tokio-stream",
|
|
|
1428 |
"tracing-subscriber",
|
1429 |
"whisper-rs",
|
|
|
1430 |
]
|
1431 |
|
1432 |
[[package]]
|
@@ -2072,9 +2075,9 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52"
|
|
2072 |
|
2073 |
[[package]]
|
2074 |
name = "tracing"
|
2075 |
-
version = "0.1.
|
2076 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
2077 |
-
checksum = "
|
2078 |
dependencies = [
|
2079 |
"log",
|
2080 |
"pin-project-lite",
|
|
|
1419 |
"aws-sdk-translate",
|
1420 |
"clap",
|
1421 |
"futures-util",
|
1422 |
+
"lazy_static",
|
1423 |
"poem",
|
1424 |
"serde",
|
1425 |
"serde_json",
|
1426 |
"serde_yaml",
|
1427 |
"tokio",
|
1428 |
"tokio-stream",
|
1429 |
+
"tracing",
|
1430 |
"tracing-subscriber",
|
1431 |
"whisper-rs",
|
1432 |
+
"whisper-rs-sys",
|
1433 |
]
|
1434 |
|
1435 |
[[package]]
|
|
|
2075 |
|
2076 |
[[package]]
|
2077 |
name = "tracing"
|
2078 |
+
version = "0.1.40"
|
2079 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
2080 |
+
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
|
2081 |
dependencies = [
|
2082 |
"log",
|
2083 |
"pin-project-lite",
|
Cargo.toml
CHANGED
@@ -12,12 +12,15 @@ clap = { version = "4.4.6" , features = ["derive"]}
|
|
12 |
tokio = { version = "1.33.0" , features = ["full"] }
|
13 |
tokio-stream = "0.1.14"
|
14 |
async-stream = "0.3.5"
|
15 |
-
tracing-subscriber = "0.3.17"
|
16 |
futures-util = "0.3.28"
|
17 |
serde = { version = "1.0.189", features = ["derive"] }
|
18 |
serde_json = { version = "1.0.107", features = [] }
|
19 |
serde_yaml = "0.9.25"
|
20 |
whisper-rs = { version = "0.8.0" , features = ["coreml"] }
|
|
|
|
|
|
|
|
|
21 |
|
22 |
[dependencies.poem]
|
23 |
version = "1.3.58"
|
|
|
12 |
tokio = { version = "1.33.0" , features = ["full"] }
|
13 |
tokio-stream = "0.1.14"
|
14 |
async-stream = "0.3.5"
|
|
|
15 |
futures-util = "0.3.28"
|
16 |
serde = { version = "1.0.189", features = ["derive"] }
|
17 |
serde_json = { version = "1.0.107", features = [] }
|
18 |
serde_yaml = "0.9.25"
|
19 |
whisper-rs = { version = "0.8.0" , features = ["coreml"] }
|
20 |
+
whisper-rs-sys = "0.6.1"
|
21 |
+
tracing = "0.1.40"
|
22 |
+
tracing-subscriber = "0.3.17"
|
23 |
+
lazy_static = "1.4.0"
|
24 |
|
25 |
[dependencies.poem]
|
26 |
version = "1.3.58"
|
config.yaml
CHANGED
@@ -2,11 +2,10 @@ server:
|
|
2 |
port: 8080
|
3 |
host: ::1
|
4 |
whisper:
|
5 |
-
n_threads: 4
|
6 |
step_ms: 500
|
7 |
length_ms: 5000
|
8 |
-
keep_ms:
|
9 |
-
capture_id: -1
|
10 |
max_tokens: 32
|
11 |
audio_ctx: 0
|
12 |
vad_thold: 0.6
|
@@ -15,7 +14,7 @@ whisper:
|
|
15 |
translate: false
|
16 |
no_fallback: false
|
17 |
print_special: false
|
18 |
-
no_context:
|
19 |
no_timestamps: false
|
20 |
tinydiarize: false
|
21 |
language: "en"
|
|
|
2 |
port: 8080
|
3 |
host: ::1
|
4 |
whisper:
|
5 |
+
# n_threads: 4
|
6 |
step_ms: 500
|
7 |
length_ms: 5000
|
8 |
+
keep_ms: 5000
|
|
|
9 |
max_tokens: 32
|
10 |
audio_ctx: 0
|
11 |
vad_thold: 0.6
|
|
|
14 |
translate: false
|
15 |
no_fallback: false
|
16 |
print_special: false
|
17 |
+
no_context: false
|
18 |
no_timestamps: false
|
19 |
tinydiarize: false
|
20 |
language: "en"
|
src/config.rs
CHANGED
@@ -1,27 +1,43 @@
|
|
1 |
use std::ffi::c_int;
|
2 |
-
use std::fs;
|
3 |
use std::net::IpAddr;
|
|
|
4 |
use serde::{Deserialize};
|
5 |
use whisper_rs::FullParams;
|
6 |
|
7 |
-
#[derive(Debug
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
pub(crate) struct WhisperParams {
|
9 |
pub(crate) n_threads: Option<usize>,
|
10 |
-
pub(crate) step_ms:
|
11 |
-
pub(crate) length_ms:
|
12 |
-
pub(crate) keep_ms:
|
13 |
-
pub(crate)
|
14 |
-
pub(crate)
|
15 |
-
pub(crate)
|
16 |
-
pub(crate)
|
17 |
-
pub(crate) freq_thold: f32,
|
18 |
pub(crate) speed_up: bool,
|
19 |
pub(crate) translate: bool,
|
20 |
pub(crate) no_fallback: bool,
|
21 |
pub(crate) print_special: bool,
|
22 |
pub(crate) no_context: bool,
|
23 |
pub(crate) no_timestamps: bool,
|
24 |
-
pub(crate) tinydiarize: bool,
|
25 |
pub(crate) language: Option<String>,
|
26 |
pub(crate) model: String,
|
27 |
}
|
@@ -29,20 +45,20 @@ pub(crate) struct WhisperParams {
|
|
29 |
const NONE: [c_int;0] = [];
|
30 |
|
31 |
impl WhisperParams {
|
32 |
-
pub(crate) fn to_full_params<'a, 'b>(&'a self) -> FullParams<'a, 'b> {
|
33 |
let mut param = FullParams::new(Default::default());
|
34 |
param.set_print_progress(false);
|
35 |
param.set_print_special(self.print_special);
|
36 |
param.set_print_realtime(false);
|
37 |
param.set_print_timestamps(!self.no_timestamps);
|
38 |
param.set_translate(self.translate);
|
39 |
-
param.set_single_segment(
|
40 |
-
param.set_max_tokens(self.max_tokens);
|
41 |
let lang = self.language.as_ref().map(|s| s.as_str());
|
42 |
param.set_language(lang);
|
43 |
let num_cpus = std::thread::available_parallelism().map(|c| c.get()).unwrap_or(4);
|
44 |
param.set_n_threads(self.n_threads.unwrap_or(num_cpus) as c_int);
|
45 |
-
param.set_audio_ctx(self.audio_ctx);
|
46 |
param.set_speed_up(self.speed_up);
|
47 |
// param.set_tdrz_enable(self.tinydiarize);
|
48 |
if self.no_fallback {
|
@@ -50,6 +66,8 @@ impl WhisperParams {
|
|
50 |
}
|
51 |
if self.no_context {
|
52 |
param.set_tokens(&NONE);
|
|
|
|
|
53 |
}
|
54 |
|
55 |
param
|
@@ -63,7 +81,7 @@ pub(crate) struct Server {
|
|
63 |
}
|
64 |
|
65 |
#[derive(Debug, Deserialize)]
|
66 |
-
pub
|
67 |
pub(crate) whisper: WhisperParams,
|
68 |
pub(crate) server: Server,
|
69 |
}
|
|
|
1 |
use std::ffi::c_int;
|
|
|
2 |
use std::net::IpAddr;
|
3 |
+
use lazy_static::lazy_static;
|
4 |
use serde::{Deserialize};
|
5 |
use whisper_rs::FullParams;
|
6 |
|
7 |
+
#[derive(Debug)]
|
8 |
+
pub enum Error {
|
9 |
+
IoError(std::io::Error),
|
10 |
+
ConfigError(serde_yaml::Error),
|
11 |
+
}
|
12 |
+
|
13 |
+
pub(crate) fn load_config() -> Result<Config, Error> {
|
14 |
+
let config_str = std::fs::read_to_string("config.yaml").map_err(|e| Error::IoError(e))?;
|
15 |
+
let config: Config = serde_yaml::from_str(config_str.as_str())
|
16 |
+
.map_err(|e| Error::ConfigError(e))?;
|
17 |
+
return Ok(config)
|
18 |
+
}
|
19 |
+
|
20 |
+
lazy_static! {
|
21 |
+
pub static ref CONFIG: Config = load_config().expect("failed to load config");
|
22 |
+
}
|
23 |
+
|
24 |
+
#[derive(Debug, Deserialize, Clone)]
|
25 |
pub(crate) struct WhisperParams {
|
26 |
pub(crate) n_threads: Option<usize>,
|
27 |
+
// pub(crate) step_ms: u32,
|
28 |
+
// pub(crate) length_ms: u32,
|
29 |
+
pub(crate) keep_ms: u32,
|
30 |
+
pub(crate) max_tokens: u32,
|
31 |
+
pub(crate) audio_ctx: u32,
|
32 |
+
// pub(crate) vad_thold: f32,
|
33 |
+
// pub(crate) freq_thold: f32,
|
|
|
34 |
pub(crate) speed_up: bool,
|
35 |
pub(crate) translate: bool,
|
36 |
pub(crate) no_fallback: bool,
|
37 |
pub(crate) print_special: bool,
|
38 |
pub(crate) no_context: bool,
|
39 |
pub(crate) no_timestamps: bool,
|
40 |
+
// pub(crate) tinydiarize: bool,
|
41 |
pub(crate) language: Option<String>,
|
42 |
pub(crate) model: String,
|
43 |
}
|
|
|
45 |
const NONE: [c_int;0] = [];
|
46 |
|
47 |
impl WhisperParams {
|
48 |
+
pub(crate) fn to_full_params<'a, 'b>(&'a self, tokens: &'b [c_int]) -> FullParams<'a, 'b> {
|
49 |
let mut param = FullParams::new(Default::default());
|
50 |
param.set_print_progress(false);
|
51 |
param.set_print_special(self.print_special);
|
52 |
param.set_print_realtime(false);
|
53 |
param.set_print_timestamps(!self.no_timestamps);
|
54 |
param.set_translate(self.translate);
|
55 |
+
param.set_single_segment(false);
|
56 |
+
param.set_max_tokens(self.max_tokens as i32);
|
57 |
let lang = self.language.as_ref().map(|s| s.as_str());
|
58 |
param.set_language(lang);
|
59 |
let num_cpus = std::thread::available_parallelism().map(|c| c.get()).unwrap_or(4);
|
60 |
param.set_n_threads(self.n_threads.unwrap_or(num_cpus) as c_int);
|
61 |
+
param.set_audio_ctx(self.audio_ctx as i32);
|
62 |
param.set_speed_up(self.speed_up);
|
63 |
// param.set_tdrz_enable(self.tinydiarize);
|
64 |
if self.no_fallback {
|
|
|
66 |
}
|
67 |
if self.no_context {
|
68 |
param.set_tokens(&NONE);
|
69 |
+
} else {
|
70 |
+
param.set_tokens(&tokens);
|
71 |
}
|
72 |
|
73 |
param
|
|
|
81 |
}
|
82 |
|
83 |
#[derive(Debug, Deserialize)]
|
84 |
+
pub struct Config {
|
85 |
pub(crate) whisper: WhisperParams,
|
86 |
pub(crate) server: Server,
|
87 |
}
|
src/group.rs
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::time::Duration;
|
2 |
+
use tokio::{select};
|
3 |
+
use tokio::time::sleep;
|
4 |
+
use tokio::sync::mpsc::{Receiver, channel};
|
5 |
+
use tokio::sync::mpsc::error::TryRecvError;
|
6 |
+
|
7 |
+
pub struct GroupedWithin<I>
|
8 |
+
where I: 'static + Send {
|
9 |
+
outlet: Receiver<Vec<I>>
|
10 |
+
}
|
11 |
+
|
12 |
+
impl<I> GroupedWithin<I>
|
13 |
+
where I: 'static + Send {
|
14 |
+
pub fn new(group_size: usize,
|
15 |
+
window_time: Duration,
|
16 |
+
mut inlet: Receiver<Vec<I>>,
|
17 |
+
buffer: usize) -> Self {
|
18 |
+
let (tx, outlet) = channel::<Vec<I>>(buffer);
|
19 |
+
tokio::spawn(async move {
|
20 |
+
let mut window = Vec::with_capacity(group_size);
|
21 |
+
|
22 |
+
loop {
|
23 |
+
let grouped_fut = async {
|
24 |
+
while let Some(c) = inlet.recv().await {
|
25 |
+
window.extend(c);
|
26 |
+
if window.len() > group_size {
|
27 |
+
let will_send: Vec<I> = window.drain(0..group_size).collect();
|
28 |
+
return Some(will_send)
|
29 |
+
}
|
30 |
+
}
|
31 |
+
return None
|
32 |
+
};
|
33 |
+
|
34 |
+
let grouped = select! {
|
35 |
+
_ = sleep(window_time) => {
|
36 |
+
window.drain(..).collect()
|
37 |
+
},
|
38 |
+
grouped_opt = grouped_fut => {
|
39 |
+
match grouped_opt {
|
40 |
+
None => break,
|
41 |
+
Some(grouped) => grouped
|
42 |
+
}
|
43 |
+
}
|
44 |
+
};
|
45 |
+
|
46 |
+
if let Err(e) = tx.send(grouped).await {
|
47 |
+
tracing::error!("{}", e);
|
48 |
+
}
|
49 |
+
}
|
50 |
+
});
|
51 |
+
Self {
|
52 |
+
outlet
|
53 |
+
}
|
54 |
+
}
|
55 |
+
|
56 |
+
pub fn next(&mut self) -> Result<Vec<I>, TryRecvError> {
|
57 |
+
self.outlet.try_recv()
|
58 |
+
}
|
59 |
+
}
|
src/main.rs
CHANGED
@@ -17,17 +17,17 @@ use poem::web::websocket::{Message, WebSocket};
|
|
17 |
use futures_util::stream::StreamExt;
|
18 |
use poem::web::{Data, Query};
|
19 |
|
20 |
-
use tokio::{
|
21 |
use serde::{Deserialize, Serialize};
|
22 |
-
use whisper_rs::WhisperContext;
|
23 |
use lesson::{LessonsManager};
|
24 |
-
use crate::config::
|
25 |
use crate::lesson::Viseme;
|
26 |
-
use crate::whisper::
|
27 |
|
28 |
mod lesson;
|
29 |
mod config;
|
30 |
mod whisper;
|
|
|
31 |
|
32 |
|
33 |
#[derive(Debug, Parser)]
|
@@ -50,26 +50,11 @@ struct Context {
|
|
50 |
lessons_manager: LessonsManager,
|
51 |
}
|
52 |
|
53 |
-
#[derive(Debug)]
|
54 |
-
enum Error {
|
55 |
-
IoError(std::io::Error),
|
56 |
-
ConfigError(serde_yaml::Error),
|
57 |
-
}
|
58 |
-
|
59 |
-
async fn load_config() -> Result<Config, Error> {
|
60 |
-
let config_str = fs::read_to_string("config.yaml").await.map_err(|e| Error::IoError(e))?;
|
61 |
-
let config: Config = serde_yaml::from_str(config_str.as_str())
|
62 |
-
.map_err(|e| Error::ConfigError(e))?;
|
63 |
-
return Ok(config)
|
64 |
-
}
|
65 |
|
66 |
#[tokio::main]
|
67 |
async fn main() -> Result<(), std::io::Error> {
|
68 |
tracing_subscriber::fmt::init();
|
69 |
|
70 |
-
let config = load_config().await.expect("failed to load config");
|
71 |
-
run_whisper(&config).await;
|
72 |
-
|
73 |
let Opt {
|
74 |
region,
|
75 |
verbose,
|
@@ -107,7 +92,8 @@ async fn main() -> Result<(), std::io::Error> {
|
|
107 |
.at("lesson-speaker", StaticFileEndpoint::new("./static/index.html"))
|
108 |
.at("lesson-listener", StaticFileEndpoint::new("./static/index.html"))
|
109 |
.data(ctx);
|
110 |
-
let
|
|
|
111 |
let server = Server::new(listener);
|
112 |
|
113 |
server.run(app).await
|
@@ -127,11 +113,20 @@ async fn stream_speaker(ctx: Data<&Context>, query: Query<LessonSpeakerQuery>, w
|
|
127 |
ws.on_upgrade(|mut socket| async move {
|
128 |
let origin_tx = lesson.voice_channel();
|
129 |
let mut transcribe_rx = lesson.transcript_channel();
|
|
|
|
|
130 |
loop {
|
131 |
select! {
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
msg = socket.next() => {
|
133 |
match msg.as_ref() {
|
134 |
Some(Ok(Message::Binary(bin))) => {
|
|
|
135 |
if origin_tx.send(bin.to_vec()).await.is_err() {
|
136 |
println!("tx closed");
|
137 |
break;
|
@@ -217,7 +212,6 @@ async fn stream_listener(ctx: Data<&Context>, query: Query<LessonListenerQuery>,
|
|
217 |
},
|
218 |
voice = voice_rx.recv() => {
|
219 |
if let Ok(voice) = voice {
|
220 |
-
println!("Synthesized: {:?}", voice.len());
|
221 |
let _ = socket.send(Message::Binary(voice)).await;
|
222 |
}
|
223 |
},
|
@@ -225,7 +219,6 @@ async fn stream_listener(ctx: Data<&Context>, query: Query<LessonListenerQuery>,
|
|
225 |
if let Ok(visemes) = visemes {
|
226 |
let evt = LiveLessonTextEvent::LipSync { visemes };
|
227 |
let json = serde_json::to_string(&evt).expect("failed to serialize");
|
228 |
-
println!("Visemes: {:?}", json);
|
229 |
let _ = socket.send(Message::Text(json)).await;
|
230 |
}
|
231 |
},
|
|
|
17 |
use futures_util::stream::StreamExt;
|
18 |
use poem::web::{Data, Query};
|
19 |
|
20 |
+
use tokio::{select};
|
21 |
use serde::{Deserialize, Serialize};
|
|
|
22 |
use lesson::{LessonsManager};
|
23 |
+
use crate::config::CONFIG;
|
24 |
use crate::lesson::Viseme;
|
25 |
+
use crate::whisper::WhisperHandler;
|
26 |
|
27 |
mod lesson;
|
28 |
mod config;
|
29 |
mod whisper;
|
30 |
+
mod group;
|
31 |
|
32 |
|
33 |
#[derive(Debug, Parser)]
|
|
|
50 |
lessons_manager: LessonsManager,
|
51 |
}
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
#[tokio::main]
|
55 |
async fn main() -> Result<(), std::io::Error> {
|
56 |
tracing_subscriber::fmt::init();
|
57 |
|
|
|
|
|
|
|
58 |
let Opt {
|
59 |
region,
|
60 |
verbose,
|
|
|
92 |
.at("lesson-speaker", StaticFileEndpoint::new("./static/index.html"))
|
93 |
.at("lesson-listener", StaticFileEndpoint::new("./static/index.html"))
|
94 |
.data(ctx);
|
95 |
+
let addr = format!("{}:{}", CONFIG.server.host, CONFIG.server.port);
|
96 |
+
let listener = TcpListener::bind(addr);
|
97 |
let server = Server::new(listener);
|
98 |
|
99 |
server.run(app).await
|
|
|
113 |
ws.on_upgrade(|mut socket| async move {
|
114 |
let origin_tx = lesson.voice_channel();
|
115 |
let mut transcribe_rx = lesson.transcript_channel();
|
116 |
+
let whisper = WhisperHandler::new(CONFIG.whisper.clone()).expect("failed to create whisper");
|
117 |
+
let mut whisper_transcribe_rx = whisper.subscribe();
|
118 |
loop {
|
119 |
select! {
|
120 |
+
w = whisper_transcribe_rx.recv() => {
|
121 |
+
let Ok(txt) = w else {
|
122 |
+
continue
|
123 |
+
};
|
124 |
+
println!("Whisper: {:?}", txt)
|
125 |
+
}
|
126 |
msg = socket.next() => {
|
127 |
match msg.as_ref() {
|
128 |
Some(Ok(Message::Binary(bin))) => {
|
129 |
+
let _ = whisper.send(bin.clone()).await; // whisper test
|
130 |
if origin_tx.send(bin.to_vec()).await.is_err() {
|
131 |
println!("tx closed");
|
132 |
break;
|
|
|
212 |
},
|
213 |
voice = voice_rx.recv() => {
|
214 |
if let Ok(voice) = voice {
|
|
|
215 |
let _ = socket.send(Message::Binary(voice)).await;
|
216 |
}
|
217 |
},
|
|
|
219 |
if let Ok(visemes) = visemes {
|
220 |
let evt = LiveLessonTextEvent::LipSync { visemes };
|
221 |
let json = serde_json::to_string(&evt).expect("failed to serialize");
|
|
|
222 |
let _ = socket.send(Message::Text(json)).await;
|
223 |
}
|
224 |
},
|
src/whisper.rs
CHANGED
@@ -1,14 +1,57 @@
|
|
1 |
-
use
|
2 |
-
use
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
}
|
10 |
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
let pcm_i16 = input
|
13 |
.chunks_exact(2)
|
14 |
.map(|chunk| {
|
@@ -17,13 +60,154 @@ async fn pcm_i16_to_f32(input: &Vec<u8>) -> Vec<f32> {
|
|
17 |
i16::from_le_bytes(buf)
|
18 |
})
|
19 |
.collect::<Vec<i16>>();
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
}
|
26 |
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::ffi::c_int;
|
2 |
+
use std::fmt::{Debug, Display, Formatter};
|
3 |
+
use std::thread::sleep;
|
4 |
+
use std::time::Duration;
|
5 |
+
use lazy_static::lazy_static;
|
6 |
+
use tokio::sync::{broadcast, mpsc, oneshot};
|
7 |
+
use whisper_rs::{convert_integer_to_float_audio, WhisperContext, WhisperState};
|
8 |
+
use whisper_rs_sys::WHISPER_SAMPLE_RATE;
|
9 |
+
use crate::config::{WhisperParams, CONFIG};
|
10 |
+
use crate::group::GroupedWithin;
|
11 |
|
12 |
+
lazy_static! {
|
13 |
+
static ref WHISPER_CONTEXT: WhisperContext = {
|
14 |
+
WhisperContext::new(&*CONFIG.whisper.model)
|
15 |
+
.expect("failed to create WhisperContext")
|
16 |
+
};
|
17 |
}
|
18 |
|
19 |
+
#[derive(Debug)]
|
20 |
+
pub(crate) enum Error {
|
21 |
+
WhisperError {
|
22 |
+
description: String,
|
23 |
+
error: whisper_rs::WhisperError,
|
24 |
+
},
|
25 |
+
}
|
26 |
+
|
27 |
+
impl Error {
|
28 |
+
fn whisper_error(description: &str, error: whisper_rs::WhisperError) -> Self {
|
29 |
+
Self::WhisperError {
|
30 |
+
description: description.to_string(),
|
31 |
+
error,
|
32 |
+
}
|
33 |
+
}
|
34 |
+
}
|
35 |
+
|
36 |
+
impl Display for Error {
|
37 |
+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
38 |
+
match self {
|
39 |
+
Self::WhisperError { description, error } => {
|
40 |
+
write!(f, "WhisperError: {}: {}", description, error)
|
41 |
+
}
|
42 |
+
}
|
43 |
+
}
|
44 |
+
}
|
45 |
+
|
46 |
+
impl std::error::Error for Error {
|
47 |
+
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
48 |
+
match self {
|
49 |
+
Self::WhisperError { error, .. } => Some(error),
|
50 |
+
}
|
51 |
+
}
|
52 |
+
}
|
53 |
+
|
54 |
+
fn pcm_i16_to_f32(input: &Vec<u8>) -> Vec<f32> {
|
55 |
let pcm_i16 = input
|
56 |
.chunks_exact(2)
|
57 |
.map(|chunk| {
|
|
|
60 |
i16::from_le_bytes(buf)
|
61 |
})
|
62 |
.collect::<Vec<i16>>();
|
63 |
+
convert_integer_to_float_audio(pcm_i16.as_slice())
|
64 |
+
}
|
65 |
+
|
66 |
+
#[derive(Clone, Debug)]
|
67 |
+
pub struct Segment {
|
68 |
+
pub start_timestamp: i64,
|
69 |
+
pub end_timestamp: i64,
|
70 |
+
pub text: String,
|
71 |
+
}
|
72 |
+
|
73 |
+
pub struct WhisperHandler {
|
74 |
+
tx: mpsc::Sender<Vec<u8>>,
|
75 |
+
transcription_tx: broadcast::Sender<Vec<Segment>>,
|
76 |
+
stop_handle: Option<oneshot::Sender<()>>,
|
77 |
+
}
|
78 |
+
|
79 |
+
impl WhisperHandler {
|
80 |
+
pub(crate) fn new(config: WhisperParams) -> Result<Self, Error> {
|
81 |
+
let n_samples_keep: usize = (config.keep_ms * WHISPER_SAMPLE_RATE / 1000) as usize;
|
82 |
+
let (stop_handle, mut stop_signal) = oneshot::channel();
|
83 |
+
let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<u8>>(128);
|
84 |
+
let mut grouped = GroupedWithin::new(
|
85 |
+
n_samples_keep,
|
86 |
+
Duration::from_secs(5),
|
87 |
+
pcm_rx,
|
88 |
+
1024
|
89 |
+
);
|
90 |
+
let (transcription_tx, _) = broadcast::channel::<Vec<Segment>>(128);
|
91 |
+
let shared_transcription_tx = transcription_tx.clone();
|
92 |
+
let mut state = WHISPER_CONTEXT.create_state()
|
93 |
+
.map_err(|e| Error::whisper_error("failed to create WhisperState", e))?;
|
94 |
+
tokio::task::spawn_blocking(move || {
|
95 |
+
let mut tokens: Vec<c_int> = Default::default();
|
96 |
+
let mut pcm_f32: Vec<f32> = Default::default();
|
97 |
+
while let Err(oneshot::error::TryRecvError::Empty) = stop_signal.try_recv() {
|
98 |
+
let new_pcm_f32 = match grouped.next() {
|
99 |
+
Err(mpsc::error::TryRecvError::Disconnected) => break,
|
100 |
+
Err(mpsc::error::TryRecvError::Empty) => {
|
101 |
+
sleep(Duration::from_millis(10));
|
102 |
+
continue
|
103 |
+
}
|
104 |
+
Ok(data) => {
|
105 |
+
pcm_i16_to_f32(&data)
|
106 |
+
}
|
107 |
+
};
|
108 |
+
|
109 |
+
pcm_f32.extend(new_pcm_f32);
|
110 |
+
match inference(&mut state, &config, n_samples_keep, &mut tokens, &mut pcm_f32) {
|
111 |
+
Ok(segments) => {
|
112 |
+
if segments.is_empty() {
|
113 |
+
continue
|
114 |
+
}
|
115 |
+
if let Err(e) = shared_transcription_tx.send(segments) {
|
116 |
+
tracing::error!("failed to send transcription: {}", e);
|
117 |
+
break
|
118 |
+
}
|
119 |
+
}
|
120 |
+
Err(err) => {
|
121 |
+
tracing::error!("failed to run whisper: {}", err);
|
122 |
+
continue
|
123 |
+
// break
|
124 |
+
}
|
125 |
+
}
|
126 |
+
}
|
127 |
+
});
|
128 |
+
Ok(Self {
|
129 |
+
tx: pcm_tx,
|
130 |
+
transcription_tx,
|
131 |
+
stop_handle: Some(stop_handle),
|
132 |
+
})
|
133 |
+
}
|
134 |
+
|
135 |
+
pub fn subscribe(&self) -> broadcast::Receiver<Vec<Segment>> {
|
136 |
+
self.transcription_tx.subscribe()
|
137 |
+
}
|
138 |
+
|
139 |
+
pub async fn send(&self, data: Vec<u8>) -> Result<(), mpsc::error::SendError<Vec<u8>>> {
|
140 |
+
self.tx.send(data).await
|
141 |
+
}
|
142 |
}
|
143 |
|
144 |
+
fn inference(
|
145 |
+
state: &mut WhisperState,
|
146 |
+
config: &WhisperParams,
|
147 |
+
n_samples_keep: usize,
|
148 |
+
prompt_tokens: &mut Vec<c_int>,
|
149 |
+
pcm_f32: &mut Vec<f32>
|
150 |
+
) -> Result<Vec<Segment>, Error> {
|
151 |
+
let params = config.to_full_params(prompt_tokens.as_slice());
|
152 |
+
|
153 |
+
let st = std::time::Instant::now();
|
154 |
+
let _ = state.full(params, pcm_f32.as_slice())
|
155 |
+
.map_err(|e| Error::whisper_error("failed to initialize WhisperState", e))?;
|
156 |
+
let et = std::time::Instant::now();
|
157 |
+
|
158 |
+
let num_segments = state
|
159 |
+
.full_n_segments()
|
160 |
+
.map_err(|e| Error::whisper_error("failed to get number of segments", e))?;
|
161 |
+
let mut segments: Vec<Segment> = Vec::with_capacity(num_segments as usize);
|
162 |
+
for i in 0..num_segments {
|
163 |
+
let segment = state
|
164 |
+
.full_get_segment_text(i)
|
165 |
+
.map_err(|e| Error::whisper_error("failed to get segment", e))?;
|
166 |
+
let start_timestamp = state
|
167 |
+
.full_get_segment_t0(i)
|
168 |
+
.map_err(|e| Error::whisper_error("failed to get start timestamp", e))?;
|
169 |
+
let end_timestamp = state
|
170 |
+
.full_get_segment_t1(i)
|
171 |
+
.map_err(|e| Error::whisper_error("failed to get end timestamp", e))?;
|
172 |
+
tracing::debug!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
|
173 |
+
segments.push(Segment { start_timestamp, end_timestamp, text: segment });
|
174 |
+
}
|
175 |
|
176 |
+
if !config.no_context {
|
177 |
+
prompt_tokens.clear();
|
178 |
+
|
179 |
+
// keep the last n_samples_keep samples from pcm_f32
|
180 |
+
if pcm_f32.len() > n_samples_keep {
|
181 |
+
let _ = pcm_f32.drain(0..(pcm_f32.len() - n_samples_keep)).collect::<Vec<_>>();
|
182 |
+
}
|
183 |
+
|
184 |
+
let n_segments = state
|
185 |
+
.full_n_segments()
|
186 |
+
.map_err(|e| Error::whisper_error("failed to get number of segments", e))?;
|
187 |
+
for i in 0..n_segments {
|
188 |
+
let token_count = state
|
189 |
+
.full_n_tokens(i)
|
190 |
+
.map_err(|e| Error::whisper_error("failed to get number of tokens", e))?;
|
191 |
+
for j in 0..token_count {
|
192 |
+
let token = state
|
193 |
+
.full_get_token_id(i, j)
|
194 |
+
.map_err(|e| Error::whisper_error("failed to get token", e))?;
|
195 |
+
prompt_tokens.push(token);
|
196 |
+
}
|
197 |
+
}
|
198 |
+
}
|
199 |
+
|
200 |
+
tracing::trace!("took {}ms", (et - st).as_millis());
|
201 |
+
Ok(segments)
|
202 |
+
}
|
203 |
+
|
204 |
+
impl Drop for WhisperHandler {
|
205 |
+
fn drop(&mut self) {
|
206 |
+
let Some(stop_handle) = self.stop_handle.take() else {
|
207 |
+
return tracing::warn!("WhisperHandler::drop() called without stop_handle");
|
208 |
+
};
|
209 |
+
if let Err(_) = stop_handle.send(()) {
|
210 |
+
tracing::warn!("WhisperHandler::drop() failed to send stop signal");
|
211 |
+
}
|
212 |
+
}
|
213 |
+
}
|