mingyang91 committed on
Commit
5df3ca5
·
verified ·
1 Parent(s): 819a8c6
Files changed (3) hide show
  1. config.yaml +2 -2
  2. src/main.rs +3 -1
  3. src/whisper.rs +14 -6
config.yaml CHANGED
@@ -1,11 +1,11 @@
1
  server:
2
  port: 8080
3
- host: ::1
4
  whisper:
5
  length_ms: 10000
6
  keep_ms: 200
7
  step_ms: 5000
8
- model: "models/ggml-base.bin"
9
  max_prompt_tokens: 128
10
  params:
11
  #n_threads: 4
 
1
  server:
2
  port: 8080
3
+ host: 0.0.0.0
4
  whisper:
5
  length_ms: 10000
6
  keep_ms: 200
7
  step_ms: 5000
8
+ model: "models/ggml-medium.en.bin" #"models/ggml-base.bin"
9
  max_prompt_tokens: 128
10
  params:
11
  #n_threads: 4
src/main.rs CHANGED
@@ -92,16 +92,18 @@ async fn main() -> Result<(), std::io::Error> {
92
  pub struct LessonSpeakerQuery {
93
  id: u32,
94
  lang: String,
 
95
  }
96
 
97
  #[handler]
98
  async fn stream_speaker(ctx: Data<&Context>, query: Query<LessonSpeakerQuery>, ws: WebSocket) -> impl IntoResponse {
99
  let lesson = ctx.lessons_manager.create_lesson(query.id, query.lang.clone().parse().expect("Not supported lang")).await;
 
100
 
101
  ws.on_upgrade(|mut socket| async move {
102
  let origin_tx = lesson.voice_channel();
103
  let mut transcribe_rx = lesson.transcript_channel();
104
- let whisper = WhisperHandler::new(CONFIG.whisper.clone()).expect("failed to create whisper");
105
  let mut whisper_transcribe_rx = whisper.subscribe();
106
  loop {
107
  select! {
 
92
  pub struct LessonSpeakerQuery {
93
  id: u32,
94
  lang: String,
95
+ prompt: Option<String>,
96
  }
97
 
98
  #[handler]
99
  async fn stream_speaker(ctx: Data<&Context>, query: Query<LessonSpeakerQuery>, ws: WebSocket) -> impl IntoResponse {
100
  let lesson = ctx.lessons_manager.create_lesson(query.id, query.lang.clone().parse().expect("Not supported lang")).await;
101
+ let prompt = query.prompt.clone().unwrap_or_default();
102
 
103
  ws.on_upgrade(|mut socket| async move {
104
  let origin_tx = lesson.voice_channel();
105
  let mut transcribe_rx = lesson.transcript_channel();
106
+ let whisper = WhisperHandler::new(CONFIG.whisper.clone(), prompt).expect("failed to create whisper");
107
  let mut whisper_transcribe_rx = whisper.subscribe();
108
  loop {
109
  select! {
src/whisper.rs CHANGED
@@ -5,9 +5,9 @@ use std::thread::sleep;
5
  use std::time::Duration;
6
  use lazy_static::lazy_static;
7
  use tokio::sync::{broadcast, mpsc, oneshot};
8
- use whisper_rs::{convert_integer_to_float_audio, WhisperState, WhisperContext};
9
  use whisper_rs_sys::WHISPER_SAMPLE_RATE;
10
- use crate::config::{WhisperParams, CONFIG, WhisperConfig};
11
  use crate::group::GroupedWithin;
12
 
13
  lazy_static! {
@@ -79,15 +79,18 @@ pub struct WhisperHandler {
79
  }
80
 
81
  impl WhisperHandler {
82
- pub(crate) fn new(config: WhisperConfig) -> Result<Self, Error> {
83
  let (stop_handle, mut stop_signal) = oneshot::channel();
84
  let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<u8>>(128);
85
  let (transcription_tx, _) = broadcast::channel::<Vec<Segment>>(128);
86
  let shared_transcription_tx = transcription_tx.clone();
87
  let state = WHISPER_CONTEXT.create_state()
88
  .map_err(|e| Error::whisper_error("failed to create WhisperState", e))?;
 
 
 
89
  tokio::task::spawn_blocking(move || {
90
- let mut detector = Detector::new(state, &CONFIG.whisper);
91
  let mut grouped = GroupedWithin::new(
92
  detector.n_samples_step * 2,
93
  Duration::from_millis(config.step_ms as u64),
@@ -154,6 +157,7 @@ impl WhisperHandler {
154
  struct Detector {
155
  state: WhisperState<'static>,
156
  config: &'static WhisperConfig,
 
157
  n_samples_keep: usize,
158
  n_samples_step: usize,
159
  n_samples_len: usize,
@@ -165,10 +169,13 @@ struct Detector {
165
 
166
  impl Detector {
167
  fn new(state: WhisperState<'static>,
168
- config: &'static WhisperConfig) -> Self {
 
 
169
  Detector {
170
  state,
171
  config,
 
172
  n_samples_keep: (config.keep_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
173
  n_samples_step: (config.step_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
174
  n_samples_len: (config.length_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
@@ -189,7 +196,8 @@ impl Detector {
189
  }
190
 
191
  fn inference(&mut self) -> Result<Vec<Segment>, Error> {
192
- let params = self.config.params.to_full_params(self.prompt_tokens.as_slice());
 
193
  let start = std::time::Instant::now();
194
  let _ = self.state.full(params, self.pcm_f32.make_contiguous())
195
  .map_err(|e| Error::whisper_error("failed to initialize WhisperState", e))?;
 
5
  use std::time::Duration;
6
  use lazy_static::lazy_static;
7
  use tokio::sync::{broadcast, mpsc, oneshot};
8
+ use whisper_rs::{convert_integer_to_float_audio, WhisperState, WhisperContext, WhisperToken};
9
  use whisper_rs_sys::WHISPER_SAMPLE_RATE;
10
+ use crate::config::{CONFIG, WhisperConfig};
11
  use crate::group::GroupedWithin;
12
 
13
  lazy_static! {
 
79
  }
80
 
81
  impl WhisperHandler {
82
+ pub(crate) fn new(config: WhisperConfig, prompt: String) -> Result<Self, Error> {
83
  let (stop_handle, mut stop_signal) = oneshot::channel();
84
  let (pcm_tx, pcm_rx) = mpsc::channel::<Vec<u8>>(128);
85
  let (transcription_tx, _) = broadcast::channel::<Vec<Segment>>(128);
86
  let shared_transcription_tx = transcription_tx.clone();
87
  let state = WHISPER_CONTEXT.create_state()
88
  .map_err(|e| Error::whisper_error("failed to create WhisperState", e))?;
89
+ let preset_prompt_tokens = WHISPER_CONTEXT
90
+ .tokenize(prompt.as_str(), CONFIG.whisper.max_prompt_tokens)
91
+ .map_err(|e| Error::whisper_error("failed to tokenize prompt", e))?;
92
  tokio::task::spawn_blocking(move || {
93
+ let mut detector = Detector::new(state, &CONFIG.whisper, preset_prompt_tokens);
94
  let mut grouped = GroupedWithin::new(
95
  detector.n_samples_step * 2,
96
  Duration::from_millis(config.step_ms as u64),
 
157
  struct Detector {
158
  state: WhisperState<'static>,
159
  config: &'static WhisperConfig,
160
+ preset_prompt_tokens: Vec<WhisperToken>,
161
  n_samples_keep: usize,
162
  n_samples_step: usize,
163
  n_samples_len: usize,
 
169
 
170
  impl Detector {
171
  fn new(state: WhisperState<'static>,
172
+ config: &'static WhisperConfig,
173
+ preset_prompt_tokens: Vec<WhisperToken>) -> Self {
174
+
175
  Detector {
176
  state,
177
  config,
178
+ preset_prompt_tokens,
179
  n_samples_keep: (config.keep_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
180
  n_samples_step: (config.step_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
181
  n_samples_len: (config.length_ms * WHISPER_SAMPLE_RATE / 1000) as usize,
 
196
  }
197
 
198
  fn inference(&mut self) -> Result<Vec<Segment>, Error> {
199
+ let prompt_tokens = [self.preset_prompt_tokens.as_slice(), self.prompt_tokens.as_slice()].concat();
200
+ let params = self.config.params.to_full_params(prompt_tokens.as_slice());
201
  let start = std::time::Instant::now();
202
  let _ = self.state.full(params, self.pcm_f32.make_contiguous())
203
  .map_err(|e| Error::whisper_error("failed to initialize WhisperState", e))?;