Spaces:
Sleeping
Sleeping
split whisper into a standalone crate
Browse files- Cargo.lock +23 -4
- Cargo.toml +9 -0
- src/asr/aws.rs +17 -0
- src/asr/whisper.rs +79 -0
- src/config.rs +5 -95
- src/main.rs +16 -10
- whisper/ggml-metal.metal +0 -0
- whisper/src/config.rs +95 -0
- {src → whisper/src}/group.rs +0 -0
- src/whisper.rs → whisper/src/handler.rs +82 -33
- whisper/src/main.rs +6 -0
Cargo.lock
CHANGED
@@ -1544,6 +1544,7 @@ version = "0.1.0"
|
|
1544 |
dependencies = [
|
1545 |
"anyhow",
|
1546 |
"async-stream",
|
|
|
1547 |
"aws-config",
|
1548 |
"aws-sdk-polly",
|
1549 |
"aws-sdk-transcribestreaming",
|
@@ -1552,6 +1553,7 @@ dependencies = [
|
|
1552 |
"futures-util",
|
1553 |
"fvad",
|
1554 |
"hound",
|
|
|
1555 |
"once_cell",
|
1556 |
"poem",
|
1557 |
"serde",
|
@@ -1562,6 +1564,7 @@ dependencies = [
|
|
1562 |
"tracing",
|
1563 |
"tracing-subscriber",
|
1564 |
"tracing-test",
|
|
|
1565 |
"whisper-rs",
|
1566 |
"whisper-rs-sys",
|
1567 |
]
|
@@ -2156,9 +2159,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
|
2156 |
|
2157 |
[[package]]
|
2158 |
name = "tokio"
|
2159 |
-
version = "1.
|
2160 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
2161 |
-
checksum = "
|
2162 |
dependencies = [
|
2163 |
"backtrace",
|
2164 |
"bytes",
|
@@ -2174,9 +2177,9 @@ dependencies = [
|
|
2174 |
|
2175 |
[[package]]
|
2176 |
name = "tokio-macros"
|
2177 |
-
version = "2.
|
2178 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
2179 |
-
checksum = "
|
2180 |
dependencies = [
|
2181 |
"proc-macro2",
|
2182 |
"quote",
|
@@ -2529,6 +2532,22 @@ dependencies = [
|
|
2529 |
"rustix",
|
2530 |
]
|
2531 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2532 |
[[package]]
|
2533 |
name = "whisper-rs"
|
2534 |
version = "0.9.0-rc.2"
|
|
|
1544 |
dependencies = [
|
1545 |
"anyhow",
|
1546 |
"async-stream",
|
1547 |
+
"async-trait",
|
1548 |
"aws-config",
|
1549 |
"aws-sdk-polly",
|
1550 |
"aws-sdk-transcribestreaming",
|
|
|
1553 |
"futures-util",
|
1554 |
"fvad",
|
1555 |
"hound",
|
1556 |
+
"lazy_static",
|
1557 |
"once_cell",
|
1558 |
"poem",
|
1559 |
"serde",
|
|
|
1564 |
"tracing",
|
1565 |
"tracing-subscriber",
|
1566 |
"tracing-test",
|
1567 |
+
"whisper",
|
1568 |
"whisper-rs",
|
1569 |
"whisper-rs-sys",
|
1570 |
]
|
|
|
2159 |
|
2160 |
[[package]]
|
2161 |
name = "tokio"
|
2162 |
+
version = "1.34.0"
|
2163 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
2164 |
+
checksum = "d0c014766411e834f7af5b8f4cf46257aab4036ca95e9d2c144a10f59ad6f5b9"
|
2165 |
dependencies = [
|
2166 |
"backtrace",
|
2167 |
"bytes",
|
|
|
2177 |
|
2178 |
[[package]]
|
2179 |
name = "tokio-macros"
|
2180 |
+
version = "2.2.0"
|
2181 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
2182 |
+
checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
|
2183 |
dependencies = [
|
2184 |
"proc-macro2",
|
2185 |
"quote",
|
|
|
2532 |
"rustix",
|
2533 |
]
|
2534 |
|
2535 |
+
[[package]]
|
2536 |
+
name = "whisper"
|
2537 |
+
version = "0.1.0"
|
2538 |
+
dependencies = [
|
2539 |
+
"fvad",
|
2540 |
+
"hound",
|
2541 |
+
"lazy_static",
|
2542 |
+
"once_cell",
|
2543 |
+
"serde",
|
2544 |
+
"tokio",
|
2545 |
+
"tracing",
|
2546 |
+
"tracing-test",
|
2547 |
+
"whisper-rs",
|
2548 |
+
"whisper-rs-sys",
|
2549 |
+
]
|
2550 |
+
|
2551 |
[[package]]
|
2552 |
name = "whisper-rs"
|
2553 |
version = "0.9.0-rc.2"
|
Cargo.toml
CHANGED
@@ -3,6 +3,9 @@ name = "polyhedron"
|
|
3 |
version = "0.1.0"
|
4 |
edition = "2021"
|
5 |
|
|
|
|
|
|
|
6 |
[dependencies]
|
7 |
anyhow = "1.0"
|
8 |
async-stream = "0.3"
|
@@ -21,6 +24,12 @@ tokio-stream = "0.1"
|
|
21 |
tracing = { version = "0.1", features = [] }
|
22 |
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
23 |
fvad = "0.1"
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
[dependencies.poem]
|
26 |
version = "1.3"
|
|
|
3 |
version = "0.1.0"
|
4 |
edition = "2021"
|
5 |
|
6 |
+
[workspace]
|
7 |
+
members = ["whisper"]
|
8 |
+
|
9 |
[dependencies]
|
10 |
anyhow = "1.0"
|
11 |
async-stream = "0.3"
|
|
|
24 |
tracing = { version = "0.1", features = [] }
|
25 |
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
26 |
fvad = "0.1"
|
27 |
+
whisper = { path="whisper", optional = true }
|
28 |
+
async-trait = "0.1.74"
|
29 |
+
lazy_static = { version = "1.4.0", features = [] }
|
30 |
+
|
31 |
+
[features]
|
32 |
+
whisper = ["dep:whisper"]
|
33 |
|
34 |
[dependencies.poem]
|
35 |
version = "1.3"
|
src/asr/aws.rs
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use async_trait::async_trait;
|
2 |
+
use tokio::sync::broadcast::Receiver;
|
3 |
+
use crate::asr::{ASR, Event};
|
4 |
+
|
5 |
+
struct AWS_ASR {
|
6 |
+
aws: aws_sdk_transcribestreaming::Client,
|
7 |
+
}
|
8 |
+
#[async_trait]
|
9 |
+
impl ASR for AWS_ASR {
|
10 |
+
async fn frame(&mut self, frame: &[i16]) -> anyhow::Result<()> {
|
11 |
+
todo!()
|
12 |
+
}
|
13 |
+
|
14 |
+
fn subscribe(&mut self) -> Receiver<Event> {
|
15 |
+
todo!()
|
16 |
+
}
|
17 |
+
}
|
src/asr/whisper.rs
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#[cfg(feature = "whisper")]
|
2 |
+
pub mod whisper_asr {
|
3 |
+
use async_trait::async_trait;
|
4 |
+
use tokio::{select, spawn};
|
5 |
+
use tokio::sync::broadcast::Receiver;
|
6 |
+
use tokio::sync::broadcast::error::RecvError;
|
7 |
+
use lazy_static::lazy_static;
|
8 |
+
use whisper::config::WhisperConfig;
|
9 |
+
|
10 |
+
extern crate whisper;
|
11 |
+
|
12 |
+
use whisper::handler::{Error, Output, WhisperHandler, Context};
|
13 |
+
use crate::asr::{ASR, Event};
|
14 |
+
use crate::config::SETTINGS;
|
15 |
+
|
16 |
+
lazy_static! {
|
17 |
+
pub static ref CONTEXT: Context = Context::new(&SETTINGS.whisper.model)
|
18 |
+
.expect("Failed to initialize whisper context");
|
19 |
+
}
|
20 |
+
|
21 |
+
pub struct Whisper_ASR {
|
22 |
+
whisper: WhisperHandler,
|
23 |
+
tx: tokio::sync::broadcast::Sender<Event>,
|
24 |
+
}
|
25 |
+
|
26 |
+
impl Whisper_ASR {
|
27 |
+
pub async fn from_config() -> Result<Whisper_ASR, Error> {
|
28 |
+
let whisper = CONTEXT.create_handler(&SETTINGS.whisper, "".to_string())?;
|
29 |
+
let mut output_rx = whisper.subscribe();
|
30 |
+
let (tx, _) = tokio::sync::broadcast::channel(64);
|
31 |
+
let shared_tx = tx.clone();
|
32 |
+
let fut = async move {
|
33 |
+
loop {
|
34 |
+
select! {
|
35 |
+
poll = output_rx.recv() => {
|
36 |
+
match poll {
|
37 |
+
Ok(outputs) => {
|
38 |
+
for output in outputs {
|
39 |
+
let res = match output {
|
40 |
+
Output::Stable(segment) => tx.send(Event {
|
41 |
+
transcript: segment.text,
|
42 |
+
is_final: true,
|
43 |
+
}),
|
44 |
+
Output::Unstable(segment) => tx.send(Event {
|
45 |
+
transcript: segment.text,
|
46 |
+
is_final: false,
|
47 |
+
}),
|
48 |
+
};
|
49 |
+
if let Err(e) = res {
|
50 |
+
tracing::warn!("Failed to send whisper event: {}", e);
|
51 |
+
break
|
52 |
+
}
|
53 |
+
}
|
54 |
+
},
|
55 |
+
Err(RecvError::Closed) => break,
|
56 |
+
Err(RecvError::Lagged(lagged)) => {
|
57 |
+
tracing::warn!("Whisper ASR output lagged: {}", lagged);
|
58 |
+
}
|
59 |
+
}
|
60 |
+
},
|
61 |
+
}
|
62 |
+
}
|
63 |
+
};
|
64 |
+
spawn(fut);
|
65 |
+
Ok(Self { whisper, tx: shared_tx })
|
66 |
+
}
|
67 |
+
}
|
68 |
+
|
69 |
+
#[async_trait]
|
70 |
+
impl ASR for Whisper_ASR {
|
71 |
+
async fn frame(&mut self, frame: &[i16]) -> anyhow::Result<()> {
|
72 |
+
Ok(self.whisper.send_i16(frame.to_vec()).await?)
|
73 |
+
}
|
74 |
+
|
75 |
+
fn subscribe(&mut self) -> Receiver<Event> {
|
76 |
+
self.tx.subscribe()
|
77 |
+
}
|
78 |
+
}
|
79 |
+
}
|
src/config.rs
CHANGED
@@ -1,106 +1,15 @@
|
|
1 |
-
use std::{env,
|
2 |
|
3 |
use config::{Config, Environment, File};
|
4 |
use once_cell::sync::Lazy;
|
5 |
use serde::Deserialize;
|
6 |
-
use whisper_rs::{FullParams};
|
7 |
use tracing::debug;
|
|
|
|
|
8 |
|
9 |
pub(crate) static SETTINGS: Lazy<Settings> =
|
10 |
Lazy::new(|| Settings::new().expect("Failed to initialize settings"));
|
11 |
|
12 |
-
#[derive(Debug, Deserialize, Clone)]
|
13 |
-
pub(crate) struct WhisperConfig {
|
14 |
-
pub(crate) params: WhisperParams,
|
15 |
-
pub(crate) step_ms: usize,
|
16 |
-
pub(crate) length_ms: usize,
|
17 |
-
pub(crate) keep_ms: usize,
|
18 |
-
pub(crate) model: String,
|
19 |
-
pub(crate) max_prompt_tokens: usize,
|
20 |
-
pub(crate) context_confidence_threshold: f32,
|
21 |
-
}
|
22 |
-
|
23 |
-
#[allow(dead_code)]
|
24 |
-
#[derive(Debug, Deserialize, Clone)]
|
25 |
-
pub(crate) struct WhisperParams {
|
26 |
-
pub(crate) n_threads: Option<usize>,
|
27 |
-
pub(crate) max_tokens: Option<u32>,
|
28 |
-
pub(crate) audio_ctx: Option<u32>,
|
29 |
-
pub(crate) speed_up: Option<bool>,
|
30 |
-
pub(crate) translate: Option<bool>,
|
31 |
-
pub(crate) no_context: Option<bool>,
|
32 |
-
pub(crate) print_special: Option<bool>,
|
33 |
-
pub(crate) print_realtime: Option<bool>,
|
34 |
-
pub(crate) print_progress: Option<bool>,
|
35 |
-
pub(crate) token_timestamps: Option<bool>,
|
36 |
-
pub(crate) no_timestamps: Option<bool>,
|
37 |
-
pub(crate) temperature_inc: Option<f32>,
|
38 |
-
pub(crate) entropy_threshold: Option<f32>,
|
39 |
-
pub(crate) single_segment: Option<bool>,
|
40 |
-
pub(crate) suppress_non_speech_tokens: Option<bool>,
|
41 |
-
pub(crate) n_max_text_ctx: Option<usize>,
|
42 |
-
// pub(crate) tinydiarize: bool,
|
43 |
-
pub(crate) language: Option<String>,
|
44 |
-
}
|
45 |
-
|
46 |
-
impl WhisperParams {
|
47 |
-
pub(crate) fn to_full_params<'a, 'b>(&'a self, tokens: &'b [c_int]) -> FullParams<'a, 'b> {
|
48 |
-
let mut param = FullParams::new(Default::default());
|
49 |
-
if let Some(print_progress) = self.print_progress.as_ref() {
|
50 |
-
param.set_print_progress(*print_progress);
|
51 |
-
}
|
52 |
-
if let Some(print_special) = self.print_special.as_ref() {
|
53 |
-
param.set_print_special(*print_special);
|
54 |
-
}
|
55 |
-
if let Some(print_realtime) = self.print_realtime.as_ref() {
|
56 |
-
param.set_print_realtime(*print_realtime);
|
57 |
-
}
|
58 |
-
if let Some(single_segment) = self.single_segment.as_ref() {
|
59 |
-
param.set_single_segment(*single_segment);
|
60 |
-
}
|
61 |
-
if let Some(no_timestamps) = self.no_timestamps.as_ref() {
|
62 |
-
param.set_print_timestamps(!no_timestamps);
|
63 |
-
}
|
64 |
-
if let Some(token_timestamps) = self.token_timestamps.as_ref() {
|
65 |
-
param.set_token_timestamps(*token_timestamps);
|
66 |
-
}
|
67 |
-
if let Some(translate) = self.translate.as_ref() {
|
68 |
-
param.set_translate(*translate);
|
69 |
-
}
|
70 |
-
if let Some(max_tokens) = self.max_tokens.as_ref() {
|
71 |
-
param.set_max_tokens(*max_tokens as i32);
|
72 |
-
}
|
73 |
-
param.set_language(self.language.as_deref());
|
74 |
-
if let Some(n_threads) = self.n_threads.as_ref() {
|
75 |
-
param.set_n_threads(*n_threads as i32);
|
76 |
-
}
|
77 |
-
if let Some(audio_ctx) = self.audio_ctx.as_ref() {
|
78 |
-
param.set_audio_ctx(*audio_ctx as i32);
|
79 |
-
}
|
80 |
-
if let Some(speed_up) = self.speed_up.as_ref() {
|
81 |
-
param.set_speed_up(*speed_up);
|
82 |
-
}
|
83 |
-
// param.set_tdrz_enable(self.tinydiarize);
|
84 |
-
if let Some(temperature_inc) = self.temperature_inc.as_ref() {
|
85 |
-
param.set_temperature_inc(*temperature_inc);
|
86 |
-
}
|
87 |
-
if let Some(suppress_non_speech_tokens) = self.suppress_non_speech_tokens.as_ref() {
|
88 |
-
param.set_suppress_non_speech_tokens(*suppress_non_speech_tokens);
|
89 |
-
}
|
90 |
-
if let Some(no_context) = self.no_context.as_ref() {
|
91 |
-
param.set_no_context(*no_context);
|
92 |
-
}
|
93 |
-
if let Some(entropy_threshold) = self.entropy_threshold.as_ref() {
|
94 |
-
param.set_entropy_thold(*entropy_threshold);
|
95 |
-
}
|
96 |
-
if let Some(n_max_text_ctx) = self.n_max_text_ctx.as_ref() {
|
97 |
-
param.set_n_max_text_ctx(*n_max_text_ctx as i32);
|
98 |
-
}
|
99 |
-
|
100 |
-
param.set_tokens(tokens);
|
101 |
-
param
|
102 |
-
}
|
103 |
-
}
|
104 |
|
105 |
#[derive(Debug, Deserialize)]
|
106 |
pub(crate) struct Server {
|
@@ -110,7 +19,8 @@ pub(crate) struct Server {
|
|
110 |
|
111 |
#[derive(Debug, Deserialize)]
|
112 |
pub struct Settings {
|
113 |
-
|
|
|
114 |
pub(crate) server: Server,
|
115 |
}
|
116 |
|
|
|
1 |
+
use std::{env, net::IpAddr};
|
2 |
|
3 |
use config::{Config, Environment, File};
|
4 |
use once_cell::sync::Lazy;
|
5 |
use serde::Deserialize;
|
|
|
6 |
use tracing::debug;
|
7 |
+
#[cfg(feature = "whisper")]
|
8 |
+
use crate::whisper;
|
9 |
|
10 |
pub(crate) static SETTINGS: Lazy<Settings> =
|
11 |
Lazy::new(|| Settings::new().expect("Failed to initialize settings"));
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
#[derive(Debug, Deserialize)]
|
15 |
pub(crate) struct Server {
|
|
|
19 |
|
20 |
#[derive(Debug, Deserialize)]
|
21 |
pub struct Settings {
|
22 |
+
#[cfg(feature = "whisper")]
|
23 |
+
pub(crate) whisper: whisper::config::WhisperConfig,
|
24 |
pub(crate) server: Server,
|
25 |
}
|
26 |
|
src/main.rs
CHANGED
@@ -5,6 +5,8 @@
|
|
5 |
|
6 |
#![allow(clippy::result_large_err)]
|
7 |
|
|
|
|
|
8 |
use aws_sdk_transcribestreaming::meta::PKG_VERSION;
|
9 |
use futures_util::{stream::StreamExt, SinkExt};
|
10 |
use poem::{
|
@@ -22,12 +24,13 @@ use tokio::select;
|
|
22 |
use tracing::debug;
|
23 |
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
24 |
|
25 |
-
use crate::{config::*, lesson
|
|
|
|
|
26 |
|
27 |
mod config;
|
28 |
-
mod group;
|
29 |
mod lesson;
|
30 |
-
mod
|
31 |
|
32 |
#[derive(Clone)]
|
33 |
struct Context {
|
@@ -104,20 +107,23 @@ async fn stream_speaker(
|
|
104 |
ws.on_upgrade(|mut socket| async move {
|
105 |
let _origin_tx = lesson.voice_channel();
|
106 |
let mut transcribe_rx = lesson.transcript_channel();
|
107 |
-
|
|
|
108 |
.expect("failed to create whisper");
|
|
|
109 |
let mut whisper_transcribe_rx = whisper.subscribe();
|
110 |
loop {
|
111 |
select! {
|
112 |
-
w = whisper_transcribe_rx.recv() => {
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
}
|
118 |
msg = socket.next() => {
|
119 |
match msg.as_ref() {
|
120 |
Some(Ok(Message::Binary(bin))) => {
|
|
|
121 |
let _ = whisper.send_bytes(bin.to_vec()).await; // whisper test
|
122 |
// if let Err(e) = origin_tx.send(bin.to_vec()).await {
|
123 |
// tracing::warn!("failed to send voice: {}", e);
|
|
|
5 |
|
6 |
#![allow(clippy::result_large_err)]
|
7 |
|
8 |
+
#[cfg(feature = "whisper")]
|
9 |
+
extern crate whisper;
|
10 |
use aws_sdk_transcribestreaming::meta::PKG_VERSION;
|
11 |
use futures_util::{stream::StreamExt, SinkExt};
|
12 |
use poem::{
|
|
|
24 |
use tracing::debug;
|
25 |
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
26 |
|
27 |
+
use crate::{config::*, lesson::*};
|
28 |
+
#[cfg(feature = "whisper")]
|
29 |
+
use crate::whisper::*;
|
30 |
|
31 |
mod config;
|
|
|
32 |
mod lesson;
|
33 |
+
mod asr;
|
34 |
|
35 |
#[derive(Clone)]
|
36 |
struct Context {
|
|
|
107 |
ws.on_upgrade(|mut socket| async move {
|
108 |
let _origin_tx = lesson.voice_channel();
|
109 |
let mut transcribe_rx = lesson.transcript_channel();
|
110 |
+
#[cfg(feature = "whisper")]
|
111 |
+
let mut whisper = asr::whisper::whisper_asr::CONTEXT.create_handler(&SETTINGS.whisper, prompt)
|
112 |
.expect("failed to create whisper");
|
113 |
+
#[cfg(feature = "whisper")]
|
114 |
let mut whisper_transcribe_rx = whisper.subscribe();
|
115 |
loop {
|
116 |
select! {
|
117 |
+
// w = whisper_transcribe_rx.recv() => {
|
118 |
+
// let Ok(_txt) = w else {
|
119 |
+
// // TODO: handle msg
|
120 |
+
// continue
|
121 |
+
// };
|
122 |
+
// }
|
123 |
msg = socket.next() => {
|
124 |
match msg.as_ref() {
|
125 |
Some(Ok(Message::Binary(bin))) => {
|
126 |
+
#[cfg(feature = "whisper")]
|
127 |
let _ = whisper.send_bytes(bin.to_vec()).await; // whisper test
|
128 |
// if let Err(e) = origin_tx.send(bin.to_vec()).await {
|
129 |
// tracing::warn!("failed to send voice: {}", e);
|
whisper/ggml-metal.metal
ADDED
The diff for this file is too large to render.
See raw diff
|
|
whisper/src/config.rs
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use serde::Deserialize;
|
2 |
+
use whisper_rs::{FullParams, WhisperToken};
|
3 |
+
|
4 |
+
#[derive(Debug, Deserialize, Clone)]
|
5 |
+
pub struct WhisperConfig {
|
6 |
+
pub(crate) params: WhisperParams,
|
7 |
+
pub(crate) step_ms: usize,
|
8 |
+
pub(crate) length_ms: usize,
|
9 |
+
pub(crate) keep_ms: usize,
|
10 |
+
pub model: String,
|
11 |
+
pub(crate) max_prompt_tokens: usize,
|
12 |
+
pub(crate) context_confidence_threshold: f32,
|
13 |
+
}
|
14 |
+
|
15 |
+
#[allow(dead_code)]
|
16 |
+
#[derive(Debug, Deserialize, Clone)]
|
17 |
+
pub(crate) struct WhisperParams {
|
18 |
+
pub(crate) n_threads: Option<usize>,
|
19 |
+
pub(crate) max_tokens: Option<u32>,
|
20 |
+
pub(crate) audio_ctx: Option<u32>,
|
21 |
+
pub(crate) speed_up: Option<bool>,
|
22 |
+
pub(crate) translate: Option<bool>,
|
23 |
+
pub(crate) no_context: Option<bool>,
|
24 |
+
pub(crate) print_special: Option<bool>,
|
25 |
+
pub(crate) print_realtime: Option<bool>,
|
26 |
+
pub(crate) print_progress: Option<bool>,
|
27 |
+
pub(crate) token_timestamps: Option<bool>,
|
28 |
+
pub(crate) no_timestamps: Option<bool>,
|
29 |
+
pub(crate) temperature_inc: Option<f32>,
|
30 |
+
pub(crate) entropy_threshold: Option<f32>,
|
31 |
+
pub(crate) single_segment: Option<bool>,
|
32 |
+
pub(crate) suppress_non_speech_tokens: Option<bool>,
|
33 |
+
pub(crate) n_max_text_ctx: Option<usize>,
|
34 |
+
// pub(crate) tinydiarize: bool,
|
35 |
+
pub(crate) language: Option<String>,
|
36 |
+
}
|
37 |
+
|
38 |
+
impl WhisperParams {
|
39 |
+
pub(crate) fn to_full_params<'a, 'b>(&'a self, tokens: &'b [WhisperToken]) -> FullParams<'a, 'b> {
|
40 |
+
let mut param = FullParams::new(Default::default());
|
41 |
+
if let Some(print_progress) = self.print_progress.as_ref() {
|
42 |
+
param.set_print_progress(*print_progress);
|
43 |
+
}
|
44 |
+
if let Some(print_special) = self.print_special.as_ref() {
|
45 |
+
param.set_print_special(*print_special);
|
46 |
+
}
|
47 |
+
if let Some(print_realtime) = self.print_realtime.as_ref() {
|
48 |
+
param.set_print_realtime(*print_realtime);
|
49 |
+
}
|
50 |
+
if let Some(single_segment) = self.single_segment.as_ref() {
|
51 |
+
param.set_single_segment(*single_segment);
|
52 |
+
}
|
53 |
+
if let Some(no_timestamps) = self.no_timestamps.as_ref() {
|
54 |
+
param.set_print_timestamps(!no_timestamps);
|
55 |
+
}
|
56 |
+
if let Some(token_timestamps) = self.token_timestamps.as_ref() {
|
57 |
+
param.set_token_timestamps(*token_timestamps);
|
58 |
+
}
|
59 |
+
if let Some(translate) = self.translate.as_ref() {
|
60 |
+
param.set_translate(*translate);
|
61 |
+
}
|
62 |
+
if let Some(max_tokens) = self.max_tokens.as_ref() {
|
63 |
+
param.set_max_tokens(*max_tokens as i32);
|
64 |
+
}
|
65 |
+
param.set_language(self.language.as_deref());
|
66 |
+
if let Some(n_threads) = self.n_threads.as_ref() {
|
67 |
+
param.set_n_threads(*n_threads as i32);
|
68 |
+
}
|
69 |
+
if let Some(audio_ctx) = self.audio_ctx.as_ref() {
|
70 |
+
param.set_audio_ctx(*audio_ctx as i32);
|
71 |
+
}
|
72 |
+
if let Some(speed_up) = self.speed_up.as_ref() {
|
73 |
+
param.set_speed_up(*speed_up);
|
74 |
+
}
|
75 |
+
// param.set_tdrz_enable(self.tinydiarize);
|
76 |
+
if let Some(temperature_inc) = self.temperature_inc.as_ref() {
|
77 |
+
param.set_temperature_inc(*temperature_inc);
|
78 |
+
}
|
79 |
+
if let Some(suppress_non_speech_tokens) = self.suppress_non_speech_tokens.as_ref() {
|
80 |
+
param.set_suppress_non_speech_tokens(*suppress_non_speech_tokens);
|
81 |
+
}
|
82 |
+
if let Some(no_context) = self.no_context.as_ref() {
|
83 |
+
param.set_no_context(*no_context);
|
84 |
+
}
|
85 |
+
if let Some(entropy_threshold) = self.entropy_threshold.as_ref() {
|
86 |
+
param.set_entropy_thold(*entropy_threshold);
|
87 |
+
}
|
88 |
+
if let Some(n_max_text_ctx) = self.n_max_text_ctx.as_ref() {
|
89 |
+
param.set_n_max_text_ctx(*n_max_text_ctx as i32);
|
90 |
+
}
|
91 |
+
|
92 |
+
param.set_tokens(tokens);
|
93 |
+
param
|
94 |
+
}
|
95 |
+
}
|
{src → whisper/src}/group.rs
RENAMED
File without changes
|
src/whisper.rs → whisper/src/handler.rs
RENAMED
@@ -6,35 +6,47 @@ use std::{
|
|
6 |
};
|
7 |
use fvad::SampleRate;
|
8 |
|
9 |
-
use
|
10 |
-
use tokio::sync::{broadcast, mpsc, oneshot};
|
11 |
use tokio::time::Instant;
|
12 |
-
use tracing::{
|
13 |
-
use whisper_rs::{convert_integer_to_float_audio, WhisperContext, WhisperState, WhisperToken, WhisperTokenData};
|
14 |
|
15 |
-
use crate::config::{Settings, SETTINGS};
|
16 |
use crate::{config::WhisperConfig, group::GroupedWithin};
|
17 |
|
18 |
const WHISPER_SAMPLE_RATE: usize = whisper_rs_sys::WHISPER_SAMPLE_RATE as usize;
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
28 |
}
|
29 |
-
|
30 |
-
|
|
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
#[derive(Debug)]
|
34 |
-
pub
|
35 |
WhisperError {
|
36 |
description: String,
|
37 |
-
error:
|
38 |
},
|
39 |
}
|
40 |
|
@@ -97,22 +109,25 @@ pub struct WhisperHandler {
|
|
97 |
}
|
98 |
|
99 |
impl WhisperHandler {
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
let vad_slice_size = WHISPER_SAMPLE_RATE / 100 * 3;
|
102 |
let (stop_handle, mut stop_signal) = oneshot::channel();
|
103 |
let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<i16>>(128);
|
104 |
let (transcription_tx, _) = broadcast::channel::<Vec<Output>>(128);
|
105 |
let shared_transcription_tx = transcription_tx.clone();
|
106 |
-
|
107 |
-
.create_state()
|
108 |
-
.map_err(|e| Error::whisper_error("failed to create WhisperState", e))?;
|
109 |
-
let preset_prompt_tokens = WHISPER_CONTEXT
|
110 |
-
.tokenize(prompt.as_str(), SETTINGS.whisper.max_prompt_tokens)
|
111 |
-
.map_err(|e| Error::whisper_error("failed to tokenize prompt", e))?;
|
112 |
tokio::task::spawn_blocking(move || {
|
113 |
let mut vad = fvad::Fvad::new().expect("failed to create VAD")
|
114 |
.set_sample_rate(SampleRate::Rate16kHz);
|
115 |
-
let mut detector = Detector::new(state, &
|
116 |
let mut grouped = GroupedWithin::new(
|
117 |
detector.n_samples_step,
|
118 |
Duration::from_millis(config.step_ms as u64),
|
@@ -368,7 +383,9 @@ mod test {
|
|
368 |
use std::io::{stdout, Write};
|
369 |
use hound;
|
370 |
use tracing_test;
|
371 |
-
use tracing::info;
|
|
|
|
|
372 |
|
373 |
async fn print_output(output: Output) {
|
374 |
match output {
|
@@ -386,18 +403,50 @@ mod test {
|
|
386 |
}
|
387 |
stdout().flush().unwrap();
|
388 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
389 |
#[tokio::test]
|
390 |
#[tracing_test::traced_test]
|
391 |
async fn test_whisper_handler() {
|
392 |
-
let mut whisper_handler =
|
393 |
-
|
394 |
-
"
|
395 |
-
).expect("failed to create WhisperHandler");
|
396 |
|
397 |
let wav = hound::WavReader::open("samples/ADHD_1A.wav")
|
398 |
.expect("failed to open wav");
|
399 |
let spec = wav.spec();
|
400 |
-
|
401 |
let samples = wav
|
402 |
.into_samples::<i16>()
|
403 |
.map(|s| s.unwrap())
|
@@ -423,9 +472,9 @@ mod test {
|
|
423 |
|
424 |
match output {
|
425 |
Output::Stable(stable) => {
|
426 |
-
|
427 |
},
|
428 |
-
Output::Unstable(
|
429 |
|
430 |
}
|
431 |
}
|
|
|
6 |
};
|
7 |
use fvad::SampleRate;
|
8 |
|
9 |
+
use tokio::sync::{broadcast, mpsc, oneshot, OnceCell};
|
|
|
10 |
use tokio::time::Instant;
|
11 |
+
use tracing::{warn};
|
12 |
+
use whisper_rs::{convert_integer_to_float_audio, WhisperContext, WhisperError, WhisperState, WhisperToken, WhisperTokenData};
|
13 |
|
|
|
14 |
use crate::{config::WhisperConfig, group::GroupedWithin};
|
15 |
|
16 |
const WHISPER_SAMPLE_RATE: usize = whisper_rs_sys::WHISPER_SAMPLE_RATE as usize;
|
17 |
|
18 |
+
pub struct Context {
|
19 |
+
context: WhisperContext,
|
20 |
+
}
|
21 |
+
|
22 |
+
impl <'a> Context {
|
23 |
+
pub fn new(model: &str) -> Result<Context, WhisperError> {
|
24 |
+
WhisperContext::new(model)
|
25 |
+
.map(|context| Self { context })
|
26 |
+
}
|
27 |
+
|
28 |
+
pub fn create_handler(&'static self, config: &'static WhisperConfig, prompt: String) -> Result<WhisperHandler, Error> {
|
29 |
+
WhisperHandler::new(&self.context, config, prompt)
|
30 |
}
|
31 |
+
}
|
32 |
+
|
33 |
+
static WHISPER_CONTEXT: OnceCell<WhisperContext> = OnceCell::const_new();
|
34 |
|
35 |
+
async fn initialize_whisper_context(model: String) -> WhisperContext {
|
36 |
+
tokio::task::spawn_blocking(move || {
|
37 |
+
WhisperContext::new(&model).expect("failed to create WhisperContext")
|
38 |
+
}).await.expect("failed to spawn")
|
39 |
+
}
|
40 |
+
|
41 |
+
async fn get_whisper_context(model: String) -> &'static WhisperContext {
|
42 |
+
WHISPER_CONTEXT.get_or_init(|| initialize_whisper_context(model)).await
|
43 |
+
}
|
44 |
|
45 |
#[derive(Debug)]
|
46 |
+
pub enum Error {
|
47 |
WhisperError {
|
48 |
description: String,
|
49 |
+
error: WhisperError,
|
50 |
},
|
51 |
}
|
52 |
|
|
|
109 |
}
|
110 |
|
111 |
impl WhisperHandler {
|
112 |
+
|
113 |
+
fn new(whisper_context: &'static WhisperContext, config: &'static WhisperConfig, prompt: String) -> Result<Self, Error> {
|
114 |
+
// let whisper_context = get_whisper_context(config.model.clone()).await;
|
115 |
+
let state = whisper_context
|
116 |
+
.create_state()
|
117 |
+
.map_err(|e| Error::whisper_error("failed to create WhisperState", e))?;
|
118 |
+
let preset_prompt_tokens = whisper_context
|
119 |
+
.tokenize(&prompt, config.max_prompt_tokens)
|
120 |
+
.map_err(|e| Error::whisper_error("failed to tokenize prompt", e))?;
|
121 |
let vad_slice_size = WHISPER_SAMPLE_RATE / 100 * 3;
|
122 |
let (stop_handle, mut stop_signal) = oneshot::channel();
|
123 |
let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<i16>>(128);
|
124 |
let (transcription_tx, _) = broadcast::channel::<Vec<Output>>(128);
|
125 |
let shared_transcription_tx = transcription_tx.clone();
|
126 |
+
|
|
|
|
|
|
|
|
|
|
|
127 |
tokio::task::spawn_blocking(move || {
|
128 |
let mut vad = fvad::Fvad::new().expect("failed to create VAD")
|
129 |
.set_sample_rate(SampleRate::Rate16kHz);
|
130 |
+
let mut detector = Detector::new(state, &config, preset_prompt_tokens);
|
131 |
let mut grouped = GroupedWithin::new(
|
132 |
detector.n_samples_step,
|
133 |
Duration::from_millis(config.step_ms as u64),
|
|
|
383 |
use std::io::{stdout, Write};
|
384 |
use hound;
|
385 |
use tracing_test;
|
386 |
+
use tracing::{info, debug};
|
387 |
+
use crate::config::WhisperParams;
|
388 |
+
use lazy_static::lazy_static;
|
389 |
|
390 |
async fn print_output(output: Output) {
|
391 |
match output {
|
|
|
403 |
}
|
404 |
stdout().flush().unwrap();
|
405 |
}
|
406 |
+
|
407 |
+
lazy_static! {
|
408 |
+
static ref CONFIG: WhisperConfig = WhisperConfig {
|
409 |
+
length_ms: 5000,
|
410 |
+
step_ms: 500,
|
411 |
+
keep_ms: 200,
|
412 |
+
model: "models/ggml-large-v3.bin".to_string(),
|
413 |
+
max_prompt_tokens: 32,
|
414 |
+
context_confidence_threshold: 0.5,
|
415 |
+
params: WhisperParams {
|
416 |
+
n_threads: None,
|
417 |
+
max_tokens: None,
|
418 |
+
audio_ctx: None,
|
419 |
+
speed_up: None,
|
420 |
+
translate: None,
|
421 |
+
no_context: None,
|
422 |
+
print_special: None,
|
423 |
+
print_realtime: None,
|
424 |
+
print_progress: None,
|
425 |
+
token_timestamps: None,
|
426 |
+
no_timestamps: None,
|
427 |
+
temperature_inc: None,
|
428 |
+
entropy_threshold: None,
|
429 |
+
single_segment: Some(true),
|
430 |
+
suppress_non_speech_tokens: None,
|
431 |
+
n_max_text_ctx: None,
|
432 |
+
language: Some("en".to_string()),
|
433 |
+
}
|
434 |
+
};
|
435 |
+
|
436 |
+
static ref CONTEXT: Context = Context::new(&CONFIG.model).expect("failed to create WhisperContext");
|
437 |
+
}
|
438 |
+
|
439 |
#[tokio::test]
|
440 |
#[tracing_test::traced_test]
|
441 |
async fn test_whisper_handler() {
|
442 |
+
let mut whisper_handler = CONTEXT
|
443 |
+
.create_handler(&CONFIG, "Harry Potter and the Philosopher's Stone".to_string())
|
444 |
+
.expect("failed to create WhisperHandler");
|
|
|
445 |
|
446 |
let wav = hound::WavReader::open("samples/ADHD_1A.wav")
|
447 |
.expect("failed to open wav");
|
448 |
let spec = wav.spec();
|
449 |
+
info!("{:?}", spec);
|
450 |
let samples = wav
|
451 |
.into_samples::<i16>()
|
452 |
.map(|s| s.unwrap())
|
|
|
472 |
|
473 |
match output {
|
474 |
Output::Stable(stable) => {
|
475 |
+
debug!("{}", stable.text);
|
476 |
},
|
477 |
+
Output::Unstable(_unstable) => {
|
478 |
|
479 |
}
|
480 |
}
|
whisper/src/main.rs
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
#[tokio::main]
|
4 |
+
async fn main() {
|
5 |
+
|
6 |
+
}
|