Prompt
- config.yaml +2 -2
- src/main.rs +3 -1
- src/whisper.rs +14 -6
config.yaml
CHANGED
@@ -1,11 +1,11 @@
 server:
   port: 8080
-  host:
+  host: 0.0.0.0
 whisper:
   length_ms: 10000
   keep_ms: 200
   step_ms: 5000
-  model: "models/ggml-base.bin"
+  model: "models/ggml-medium.en.bin" #"models/ggml-base.bin"
   max_prompt_tokens: 128
   params:
     #n_threads: 4
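For orientation, here is a minimal sketch of how this file could map onto serde-deserialized structs on the Rust side. The field names mirror the YAML above, but the struct layout, the derives, and the handling of `params` are assumptions for illustration, not the repository's actual crate::config module.

use serde::Deserialize;

// Hypothetical mirror of config.yaml; the real CONFIG/WhisperConfig types may differ.
#[derive(Debug, Clone, Deserialize)]
pub struct ServerConfig {
    pub port: u16,
    pub host: String, // now populated explicitly: 0.0.0.0
}

#[derive(Debug, Clone, Deserialize)]
pub struct WhisperConfig {
    pub length_ms: u32,
    pub keep_ms: u32,
    pub step_ms: u32,
    pub model: String,            // switched to models/ggml-medium.en.bin
    pub max_prompt_tokens: usize, // cap used when tokenizing the preset prompt
    // `params` (e.g. n_threads) omitted here; its concrete type is not shown in the diff.
}

#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    pub server: ServerConfig,
    pub whisper: WhisperConfig,
}

fn load_config(yaml: &str) -> Result<Config, serde_yaml::Error> {
    serde_yaml::from_str(yaml)
}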
src/main.rs
CHANGED
@@ -92,16 +92,18 @@ async fn main() -> Result<(), std::io::Error> {
 pub struct LessonSpeakerQuery {
     id: u32,
     lang: String,
+    prompt: Option<String>,
 }
 
 #[handler]
 async fn stream_speaker(ctx: Data<&Context>, query: Query<LessonSpeakerQuery>, ws: WebSocket) -> impl IntoResponse {
     let lesson = ctx.lessons_manager.create_lesson(query.id, query.lang.clone().parse().expect("Not supported lang")).await;
+    let prompt = query.prompt.clone().unwrap_or_default();
 
     ws.on_upgrade(|mut socket| async move {
         let origin_tx = lesson.voice_channel();
         let mut transcribe_rx = lesson.transcript_channel();
-        let whisper = WhisperHandler::new(CONFIG.whisper.clone()).expect("failed to create whisper");
+        let whisper = WhisperHandler::new(CONFIG.whisper.clone(), prompt).expect("failed to create whisper");
         let mut whisper_transcribe_rx = whisper.subscribe();
         loop {
             select! {
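The new `prompt` query parameter rides along on the existing WebSocket upgrade request. A minimal client sketch follows; the route path `/stream_speaker`, the host, the lesson id, and the `lang` value are assumptions for illustration (only the `prompt` parameter itself comes from this change), and it uses the tokio-tungstenite crate rather than anything from this repository.

use tokio_tungstenite::connect_async;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // The prompt is passed as a URL-encoded query parameter; if it is omitted,
    // the server falls back to an empty prompt via unwrap_or_default().
    let url = "ws://127.0.0.1:8080/stream_speaker?id=1&lang=en&prompt=Biology%20lecture%20vocabulary";
    let (_socket, _response) = connect_async(url).await?;
    // ... send PCM audio frames over `_socket` and read transcript messages back ...
    Ok(())
}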
src/whisper.rs
CHANGED
@@ -5,9 +5,9 @@ use std::thread::sleep;
 use std::time::Duration;
 use lazy_static::lazy_static;
 use tokio::sync::{broadcast, mpsc, oneshot};
-use whisper_rs::{convert_integer_to_float_audio, WhisperState, WhisperContext};
+use whisper_rs::{convert_integer_to_float_audio, WhisperState, WhisperContext, WhisperToken};
 use whisper_rs_sys::WHISPER_SAMPLE_RATE;
-use crate::config::{
+use crate::config::{CONFIG, WhisperConfig};
 use crate::group::GroupedWithin;
 
 lazy_static! {
@@ -79,15 +79,18 @@ pub struct WhisperHandler {
 }
 
 impl WhisperHandler {
-    pub(crate) fn new(config: WhisperConfig) -> Result<Self, Error> {
+    pub(crate) fn new(config: WhisperConfig, prompt: String) -> Result<Self, Error> {
         let (stop_handle, mut stop_signal) = oneshot::channel();
         let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<u8>>(128);
         let (transcription_tx, _) = broadcast::channel::<Vec<Segment>>(128);
         let shared_transcription_tx = transcription_tx.clone();
         let state = WHISPER_CONTEXT.create_state()
             .map_err(|e| Error::whisper_error("failed to create WhisperState", e))?;
+        let preset_prompt_tokens = WHISPER_CONTEXT
+            .tokenize(prompt.as_str(), CONFIG.whisper.max_prompt_tokens)
+            .map_err(|e| Error::whisper_error("failed to tokenize prompt", e))?;
         tokio::task::spawn_blocking(move || {
-            let mut detector = Detector::new(state, &CONFIG.whisper);
+            let mut detector = Detector::new(state, &CONFIG.whisper, preset_prompt_tokens);
             let mut grouped = GroupedWithin::new(
                 detector.n_samples_step * 2,
                 Duration::from_millis(config.step_ms as u64),
@@ -154,6 +157,7 @@ impl WhisperHandler {
 struct Detector {
     state: WhisperState<'static>,
     config: &'static WhisperConfig,
+    preset_prompt_tokens: Vec<WhisperToken>,
     n_samples_keep: usize,
     n_samples_step: usize,
     n_samples_len: usize,
@@ -165,10 +169,13 @@ struct Detector {
 
 impl Detector {
     fn new(state: WhisperState<'static>,
-           config: &'static WhisperConfig) -> Self {
+           config: &'static WhisperConfig,
+           preset_prompt_tokens: Vec<WhisperToken>) -> Self {
+
         Detector {
             state,
             config,
+            preset_prompt_tokens,
             n_samples_keep: (config.keep_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
             n_samples_step: (config.step_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
             n_samples_len: (config.length_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
@@ -189,7 +196,8 @@ impl Detector {
     }
 
     fn inference(&mut self) -> Result<Vec<Segment>, Error> {
-        let
+        let prompt_tokens = [self.preset_prompt_tokens.as_slice(), self.prompt_tokens.as_slice()].concat();
+        let params = self.config.params.to_full_params(prompt_tokens.as_slice());
         let start = std::time::Instant::now();
         let _ = self.state.full(params, self.pcm_f32.make_contiguous())
             .map_err(|e| Error::whisper_error("failed to initialize WhisperState", e))?;