Spaces:
Sleeping
Sleeping
Matrix
committed on
Commit
·
3569cbd
1
Parent(s):
9ce120f
chore: use once_cell lazy instead of lazy_static
Browse files- Cargo.lock +1 -1
- Cargo.toml +1 -1
- src/config.rs +6 -20
- src/group.rs +1 -1
- src/lesson.rs +3 -3
- src/main.rs +16 -14
- src/whisper.rs +29 -25
Cargo.lock
CHANGED
@@ -1449,7 +1449,7 @@ dependencies = [
|
|
1449 |
"aws-sdk-translate",
|
1450 |
"config",
|
1451 |
"futures-util",
|
1452 |
-
"lazy_static",
|
1453 |
"poem",
|
1454 |
"serde",
|
1455 |
"serde_json",
|
|
|
1449 |
"aws-sdk-translate",
|
1450 |
"config",
|
1451 |
"futures-util",
|
1452 |
+
"once_cell",
|
1453 |
"poem",
|
1454 |
"serde",
|
1455 |
"serde_json",
|
Cargo.toml
CHANGED
@@ -12,7 +12,7 @@ aws-sdk-translate = "0.34"
|
|
12 |
aws-sdk-polly = "0.34"
|
13 |
config = "0.13"
|
14 |
futures-util = "0.3"
|
15 |
-
|
16 |
serde = { version = "1.0", features = ["derive"] }
|
17 |
serde_json = "1.0"
|
18 |
serde_yaml = "0.9"
|
|
|
12 |
aws-sdk-polly = "0.34"
|
13 |
config = "0.13"
|
14 |
futures-util = "0.3"
|
15 |
+
once_cell = "1.18"
|
16 |
serde = { version = "1.0", features = ["derive"] }
|
17 |
serde_json = "1.0"
|
18 |
serde_yaml = "0.9"
|
src/config.rs
CHANGED
@@ -1,26 +1,12 @@
|
|
1 |
-
use std::{ffi::c_int, net::IpAddr
|
2 |
|
3 |
-
use config::{
|
4 |
-
use
|
5 |
use serde::Deserialize;
|
6 |
use whisper_rs::FullParams;
|
7 |
|
8 |
-
|
9 |
-
|
10 |
-
IoError(std::io::Error),
|
11 |
-
ConfigError(serde_yaml::Error),
|
12 |
-
}
|
13 |
-
|
14 |
-
pub(crate) fn load_config() -> Result<Settings, Error> {
|
15 |
-
let config_str = std::fs::read_to_string("../config/dev.yaml").map_err(|e| Error::IoError(e))?;
|
16 |
-
let config: Settings =
|
17 |
-
serde_yaml::from_str(config_str.as_str()).map_err(|e| Error::ConfigError(e))?;
|
18 |
-
return Ok(config);
|
19 |
-
}
|
20 |
-
|
21 |
-
lazy_static! {
|
22 |
-
pub static ref CONFIG: Settings = load_config().expect("failed to load config");
|
23 |
-
}
|
24 |
|
25 |
#[derive(Debug, Deserialize, Clone)]
|
26 |
pub(crate) struct WhisperConfig {
|
@@ -62,7 +48,7 @@ impl WhisperParams {
|
|
62 |
param.set_translate(self.translate);
|
63 |
param.set_single_segment(false);
|
64 |
param.set_max_tokens(self.max_tokens as i32);
|
65 |
-
let lang = self.language.
|
66 |
param.set_language(lang);
|
67 |
let num_cpus = std::thread::available_parallelism()
|
68 |
.map(|c| c.get())
|
|
|
1 |
+
use std::{env, ffi::c_int, net::IpAddr};
|
2 |
|
3 |
+
use config::{Config, Environment, File};
|
4 |
+
use once_cell::sync::Lazy;
|
5 |
use serde::Deserialize;
|
6 |
use whisper_rs::FullParams;
|
7 |
|
8 |
+
pub(crate) static SETTINGS: Lazy<Settings> =
|
9 |
+
Lazy::new(|| Settings::new().expect("Failed to initialize settings"));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
#[derive(Debug, Deserialize, Clone)]
|
12 |
pub(crate) struct WhisperConfig {
|
|
|
48 |
param.set_translate(self.translate);
|
49 |
param.set_single_segment(false);
|
50 |
param.set_max_tokens(self.max_tokens as i32);
|
51 |
+
let lang = self.language.as_deref();
|
52 |
param.set_language(lang);
|
53 |
let num_cpus = std::thread::available_parallelism()
|
54 |
.map(|c| c.get())
|
src/group.rs
CHANGED
@@ -35,7 +35,7 @@ where
|
|
35 |
return Some(will_send);
|
36 |
}
|
37 |
}
|
38 |
-
|
39 |
};
|
40 |
|
41 |
let grouped: Vec<I> = select! {
|
|
|
35 |
return Some(will_send);
|
36 |
}
|
37 |
}
|
38 |
+
None
|
39 |
};
|
40 |
|
41 |
let grouped: Vec<I> = select! {
|
src/lesson.rs
CHANGED
@@ -29,9 +29,9 @@ pub struct LessonsManager {
|
|
29 |
|
30 |
impl LessonsManager {
|
31 |
pub(crate) fn new(sdk_config: &SdkConfig) -> Self {
|
32 |
-
let transcript_client = aws_sdk_transcribestreaming::Client::new(
|
33 |
-
let translate_client = aws_sdk_translate::Client::new(
|
34 |
-
let polly_client = aws_sdk_polly::Client::new(
|
35 |
LessonsManager {
|
36 |
translate_client,
|
37 |
polly_client,
|
|
|
29 |
|
30 |
impl LessonsManager {
|
31 |
pub(crate) fn new(sdk_config: &SdkConfig) -> Self {
|
32 |
+
let transcript_client = aws_sdk_transcribestreaming::Client::new(sdk_config);
|
33 |
+
let translate_client = aws_sdk_translate::Client::new(sdk_config);
|
34 |
+
let polly_client = aws_sdk_polly::Client::new(sdk_config);
|
35 |
LessonsManager {
|
36 |
translate_client,
|
37 |
polly_client,
|
src/main.rs
CHANGED
@@ -5,21 +5,23 @@
|
|
5 |
|
6 |
#![allow(clippy::result_large_err)]
|
7 |
|
8 |
-
use aws_sdk_transcribestreaming::
|
9 |
-
use futures_util::stream::StreamExt;
|
10 |
-
use
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
use serde::{Deserialize, Serialize};
|
21 |
use tokio::select;
|
22 |
|
|
|
|
|
23 |
mod config;
|
24 |
mod group;
|
25 |
mod lesson;
|
@@ -61,7 +63,7 @@ async fn main() -> Result<(), std::io::Error> {
|
|
61 |
StaticFileEndpoint::new("./static/index.html"),
|
62 |
)
|
63 |
.data(ctx);
|
64 |
-
let addr = format!("{}:{}",
|
65 |
let listener = TcpListener::bind(addr);
|
66 |
let server = Server::new(listener);
|
67 |
|
@@ -94,7 +96,7 @@ async fn stream_speaker(
|
|
94 |
let _origin_tx = lesson.voice_channel();
|
95 |
let mut transcribe_rx = lesson.transcript_channel();
|
96 |
let whisper =
|
97 |
-
WhisperHandler::new(
|
98 |
let mut whisper_transcribe_rx = whisper.subscribe();
|
99 |
loop {
|
100 |
select! {
|
|
|
5 |
|
6 |
#![allow(clippy::result_large_err)]
|
7 |
|
8 |
+
use aws_sdk_transcribestreaming::meta::PKG_VERSION;
|
9 |
+
use futures_util::{stream::StreamExt, SinkExt};
|
10 |
+
use poem::{
|
11 |
+
endpoint::{StaticFileEndpoint, StaticFilesEndpoint},
|
12 |
+
get, handler,
|
13 |
+
listener::TcpListener,
|
14 |
+
web::{
|
15 |
+
websocket::{Message, WebSocket},
|
16 |
+
Data, Query,
|
17 |
+
},
|
18 |
+
EndpointExt, IntoResponse, Route, Server,
|
19 |
+
};
|
20 |
use serde::{Deserialize, Serialize};
|
21 |
use tokio::select;
|
22 |
|
23 |
+
use crate::{config::*, lesson::*, whisper::*};
|
24 |
+
|
25 |
mod config;
|
26 |
mod group;
|
27 |
mod lesson;
|
|
|
63 |
StaticFileEndpoint::new("./static/index.html"),
|
64 |
)
|
65 |
.data(ctx);
|
66 |
+
let addr = format!("{}:{}", SETTINGS.server.host, SETTINGS.server.port);
|
67 |
let listener = TcpListener::bind(addr);
|
68 |
let server = Server::new(listener);
|
69 |
|
|
|
96 |
let _origin_tx = lesson.voice_channel();
|
97 |
let mut transcribe_rx = lesson.transcript_channel();
|
98 |
let whisper =
|
99 |
+
WhisperHandler::new(SETTINGS.whisper.clone(), prompt).expect("failed to create whisper");
|
100 |
let mut whisper_transcribe_rx = whisper.subscribe();
|
101 |
loop {
|
102 |
select! {
|
src/whisper.rs
CHANGED
@@ -1,19 +1,23 @@
|
|
1 |
-
use
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
|
|
9 |
use tokio::sync::{broadcast, mpsc, oneshot};
|
10 |
use whisper_rs::{convert_integer_to_float_audio, WhisperContext, WhisperState, WhisperToken};
|
11 |
use whisper_rs_sys::WHISPER_SAMPLE_RATE;
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
17 |
|
18 |
#[derive(Debug)]
|
19 |
pub(crate) enum Error {
|
@@ -86,10 +90,10 @@ impl WhisperHandler {
|
|
86 |
.create_state()
|
87 |
.map_err(|e| Error::whisper_error("failed to create WhisperState", e))?;
|
88 |
let preset_prompt_tokens = WHISPER_CONTEXT
|
89 |
-
.tokenize(prompt.as_str(),
|
90 |
.map_err(|e| Error::whisper_error("failed to tokenize prompt", e))?;
|
91 |
tokio::task::spawn_blocking(move || {
|
92 |
-
let mut detector = Detector::new(state, &
|
93 |
let mut grouped = GroupedWithin::new(
|
94 |
detector.n_samples_step * 2,
|
95 |
Duration::from_millis(config.step_ms as u64),
|
@@ -203,7 +207,7 @@ impl Detector {
|
|
203 |
self.preset_prompt_tokens.as_slice(),
|
204 |
self.prompt_tokens.as_slice(),
|
205 |
]
|
206 |
-
|
207 |
let params = self.config.params.to_full_params(prompt_tokens.as_slice());
|
208 |
let start = std::time::Instant::now();
|
209 |
let _ = self
|
@@ -229,18 +233,18 @@ impl Detector {
|
|
229 |
for i in 0..num_segments {
|
230 |
let end_timestamp: i64 = timestamp_offset
|
231 |
+ 10 * self
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
if end_timestamp <= stable_offset {
|
236 |
continue;
|
237 |
}
|
238 |
|
239 |
let start_timestamp: i64 = timestamp_offset
|
240 |
+ 10 * self
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
let segment = self
|
245 |
.state
|
246 |
.full_get_segment_text(i)
|
@@ -285,12 +289,12 @@ impl Detector {
|
|
285 |
let Some(last) = stable_segments.last() else {
|
286 |
return;
|
287 |
};
|
288 |
-
let drop_offset: usize =
|
289 |
-
- self.offset
|
290 |
let len_to_drain = self.pcm_f32.drain(0..drop_offset).len();
|
291 |
self.offset += len_to_drain;
|
292 |
|
293 |
-
for segment in stable_segments.
|
294 |
self.prompt_tokens.extend(&segment.tokens);
|
295 |
}
|
296 |
if self.prompt_tokens.len() > self.config.max_prompt_tokens {
|
@@ -307,7 +311,7 @@ impl Drop for WhisperHandler {
|
|
307 |
let Some(stop_handle) = self.stop_handle.take() else {
|
308 |
return tracing::warn!("WhisperHandler::drop() called without stop_handle");
|
309 |
};
|
310 |
-
if
|
311 |
tracing::warn!("WhisperHandler::drop() failed to send stop signal");
|
312 |
}
|
313 |
}
|
|
|
1 |
+
use std::{
|
2 |
+
collections::VecDeque,
|
3 |
+
ffi::c_int,
|
4 |
+
fmt::{Debug, Display, Formatter},
|
5 |
+
thread::sleep,
|
6 |
+
time::Duration,
|
7 |
+
};
|
8 |
+
|
9 |
+
use once_cell::sync::Lazy;
|
10 |
use tokio::sync::{broadcast, mpsc, oneshot};
|
11 |
use whisper_rs::{convert_integer_to_float_audio, WhisperContext, WhisperState, WhisperToken};
|
12 |
use whisper_rs_sys::WHISPER_SAMPLE_RATE;
|
13 |
|
14 |
+
use crate::config::{Settings, SETTINGS};
|
15 |
+
use crate::{config::WhisperConfig, group::GroupedWithin};
|
16 |
+
|
17 |
+
static WHISPER_CONTEXT: Lazy<WhisperContext> = Lazy::new(|| {
|
18 |
+
let settings = Settings::new().expect("Failed to initialize settings.");
|
19 |
+
WhisperContext::new(&settings.whisper.model).expect("failed to create WhisperContext")
|
20 |
+
});
|
21 |
|
22 |
#[derive(Debug)]
|
23 |
pub(crate) enum Error {
|
|
|
90 |
.create_state()
|
91 |
.map_err(|e| Error::whisper_error("failed to create WhisperState", e))?;
|
92 |
let preset_prompt_tokens = WHISPER_CONTEXT
|
93 |
+
.tokenize(prompt.as_str(), SETTINGS.whisper.max_prompt_tokens)
|
94 |
.map_err(|e| Error::whisper_error("failed to tokenize prompt", e))?;
|
95 |
tokio::task::spawn_blocking(move || {
|
96 |
+
let mut detector = Detector::new(state, &SETTINGS.whisper, preset_prompt_tokens);
|
97 |
let mut grouped = GroupedWithin::new(
|
98 |
detector.n_samples_step * 2,
|
99 |
Duration::from_millis(config.step_ms as u64),
|
|
|
207 |
self.preset_prompt_tokens.as_slice(),
|
208 |
self.prompt_tokens.as_slice(),
|
209 |
]
|
210 |
+
.concat();
|
211 |
let params = self.config.params.to_full_params(prompt_tokens.as_slice());
|
212 |
let start = std::time::Instant::now();
|
213 |
let _ = self
|
|
|
233 |
for i in 0..num_segments {
|
234 |
let end_timestamp: i64 = timestamp_offset
|
235 |
+ 10 * self
|
236 |
+
.state
|
237 |
+
.full_get_segment_t1(i)
|
238 |
+
.map_err(|e| Error::whisper_error("failed to get end timestamp", e))?;
|
239 |
if end_timestamp <= stable_offset {
|
240 |
continue;
|
241 |
}
|
242 |
|
243 |
let start_timestamp: i64 = timestamp_offset
|
244 |
+ 10 * self
|
245 |
+
.state
|
246 |
+
.full_get_segment_t0(i)
|
247 |
+
.map_err(|e| Error::whisper_error("failed to get start timestamp", e))?;
|
248 |
let segment = self
|
249 |
.state
|
250 |
.full_get_segment_text(i)
|
|
|
289 |
let Some(last) = stable_segments.last() else {
|
290 |
return;
|
291 |
};
|
292 |
+
let drop_offset: usize =
|
293 |
+
last.end_timestamp as usize / 1000 * WHISPER_SAMPLE_RATE as usize - self.offset;
|
294 |
let len_to_drain = self.pcm_f32.drain(0..drop_offset).len();
|
295 |
self.offset += len_to_drain;
|
296 |
|
297 |
+
for segment in stable_segments.iter() {
|
298 |
self.prompt_tokens.extend(&segment.tokens);
|
299 |
}
|
300 |
if self.prompt_tokens.len() > self.config.max_prompt_tokens {
|
|
|
311 |
let Some(stop_handle) = self.stop_handle.take() else {
|
312 |
return tracing::warn!("WhisperHandler::drop() called without stop_handle");
|
313 |
};
|
314 |
+
if stop_handle.send(()).is_err() {
|
315 |
tracing::warn!("WhisperHandler::drop() failed to send stop signal");
|
316 |
}
|
317 |
}
|