Fedir Zadniprovskyi committed
Commit 3a0bd05 · 1 Parent(s): 9b178fc

feat: support BatchedInferencePipeline (#169)

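For context: `faster-whisper` 1.1.0 adds `BatchedInferencePipeline`, a wrapper around an already-loaded `WhisperModel` that (roughly) VAD-segments the audio and decodes several chunks per forward pass, while exposing the same `transcribe()` interface. A minimal sketch of the upstream API as documented for 1.1.0; the model name, audio file, and `batch_size` value are illustrative:

from faster_whisper import WhisperModel
from faster_whisper.transcribe import BatchedInferencePipeline

# Load a regular model, then wrap it. The wrapper holds a reference to the
# loaded model; it does not load any weights itself.
model = WhisperModel("Systran/faster-whisper-small", device="cpu", compute_type="int8")
batched_model = BatchedInferencePipeline(model=model)

# `batch_size` controls how many audio chunks are decoded per forward pass.
segments, info = batched_model.transcribe("audio.mp3", batch_size=8)
for segment in segments:  # segments are yielded lazily
    print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")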
pyproject.toml CHANGED
@@ -6,7 +6,7 @@ requires-python = ">=3.12,<3.13"
 dependencies = [
     "ctranslate2>=4.5.0",
     "fastapi>=0.115.0",
-    "faster-whisper>=1.0.3",
+    "faster-whisper>=1.1.0",
     "huggingface-hub>=0.25.1",
     "numpy>=2.1.1",
     "piper-phonemize ; platform_machine == 'x86_64'",
src/faster_whisper_server/asr.py CHANGED
@@ -31,6 +31,7 @@ class FasterWhisperASR:
         prompt: str | None = None,
     ) -> tuple[Transcription, transcribe.TranscriptionInfo]:
         start = time.perf_counter()
+        # NOTE: should `BatchedInferencePipeline` be used here?
         segments, transcription_info = self.whisper.transcribe(
             audio.data,
             initial_prompt=prompt,
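The `NOTE` added above is left unresolved by this commit: `FasterWhisperASR` still calls `self.whisper.transcribe(...)` directly. If it were to honor the same flag as `stt.py`, the change might look like the sketch below; this is a hypothetical refactor, not part of the commit, and the free-function signature is invented for illustration:

from faster_whisper import WhisperModel
from faster_whisper.transcribe import BatchedInferencePipeline

def transcribe_maybe_batched(whisper: WhisperModel, audio_data, prompt: str | None, use_batched_mode: bool):
    # Same conditional wrapping as in routers/stt.py: the pipeline and the
    # plain model share the `transcribe()` interface, so the call site below
    # is identical either way.
    model = BatchedInferencePipeline(model=whisper) if use_batched_mode else whisper
    return model.transcribe(audio_data, initial_prompt=prompt)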
src/faster_whisper_server/config.py CHANGED
@@ -168,6 +168,10 @@ class WhisperConfig(BaseModel):
     -1: Never unload the model.
     0: Unload the model immediately after usage.
     """
+    use_batched_mode: bool = False
+    """
+    Whether to use batched mode (introduced in the 1.1.0 `faster-whisper` release) for inference. This will likely become the default in the future, and the configuration option will be removed.
+    """  # noqa: E501
 
 
 class Config(BaseSettings):
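The new field can be set like any other `WhisperConfig` field. A minimal sketch, assuming the model's remaining fields all have defaults; the environment-variable spelling in the comment is an assumption, since this diff does not show how `Config(BaseSettings)` maps nested fields:

from faster_whisper_server.config import WhisperConfig

# Programmatic opt-in; `use_batched_mode` defaults to False.
whisper_config = WhisperConfig(use_batched_mode=True)
assert whisper_config.use_batched_mode

# Via the environment this would be something like WHISPER__USE_BATCHED_MODE=true,
# depending on the (not shown here) pydantic-settings nested-delimiter setup.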
src/faster_whisper_server/routers/stt.py CHANGED
@@ -21,6 +21,7 @@ from fastapi.exceptions import HTTPException
 from fastapi.responses import StreamingResponse
 from fastapi.websockets import WebSocketState
 from faster_whisper.audio import decode_audio
+from faster_whisper.transcribe import BatchedInferencePipeline
 from faster_whisper.vad import VadOptions, get_speech_timestamps
 from numpy import float32
 from numpy.typing import NDArray
@@ -188,7 +189,8 @@ def translate_file(
     if response_format is None:
         response_format = config.default_response_format
     with model_manager.load_model(model) as whisper:
-        segments, transcription_info = whisper.transcribe(
+        whisper_model = BatchedInferencePipeline(model=whisper) if config.whisper.use_batched_mode else whisper
+        segments, transcription_info = whisper_model.transcribe(
             audio,
             task=Task.TRANSLATE,
             initial_prompt=prompt,
@@ -252,7 +254,8 @@ def transcribe_file(
             "It only makes sense to provide `timestamp_granularities[]` when `response_format` is set to `verbose_json`. See https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities."  # noqa: E501
         )
     with model_manager.load_model(model) as whisper:
-        segments, transcription_info = whisper.transcribe(
+        whisper_model = BatchedInferencePipeline(model=whisper) if config.whisper.use_batched_mode else whisper
+        segments, transcription_info = whisper_model.transcribe(
             audio,
             task=Task.TRANSCRIBE,
             language=language,
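Batched mode is transparent to clients: both endpoints stay OpenAI-compatible, and whether batching is used is decided server-side by `config.whisper.use_batched_mode`. A hedged client-side sketch using the OpenAI SDK (base URL, API key, and model name are illustrative, not taken from this commit):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")

# POST /v1/audio/transcriptions -- served by transcribe_file() above.
with open("audio.wav", "rb") as audio_file:
    transcript = client.audio.transcriptions.create(
        model="Systran/faster-whisper-small",
        file=audio_file,
    )
print(transcript.text)

Note also that `BatchedInferencePipeline(model=whisper)` is constructed per request. That appears cheap, since the wrapper only holds a reference to the already-loaded model, but the unresolved `NOTE` in `asr.py` suggests the integration point may still change.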
uv.lock CHANGED
@@ -230,7 +230,7 @@ wheels = [
 
 [[package]]
 name = "faster-whisper"
-version = "1.0.3"
+version = "1.1.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "av" },
@@ -238,10 +238,11 @@ dependencies = [
     { name = "huggingface-hub" },
     { name = "onnxruntime" },
     { name = "tokenizers" },
+    { name = "tqdm" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/1e/f2/77437ee937233d6e8259e3df511a4662cd7e833dabeaaddbfc929d2a3ed5/faster-whisper-1.0.3.tar.gz", hash = "sha256:1a145db86450b56aaa623c8df7d4ef86e8a1159900f60533e2890e98e8453a17", size = 1980019 }
+sdist = { url = "https://files.pythonhosted.org/packages/31/b1/124f6d5a547756170e11eea405ae6c08afa2b96e8ccd10947a1244b50cdb/faster-whisper-1.1.0.tar.gz", hash = "sha256:cea4bba5d4527173fdbacafa56f2ffb17dd322688f6c3fdf5fd7b6b6c193ce17", size = 1124950 }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/7f/00/4742b1cd3afd23d0ff9b7e72ec40b2c398988332a5578115728fd83415d1/faster_whisper-1.0.3-py3-none-any.whl", hash = "sha256:364d0e378ab232ed26f39656e5c98548b38045224e206b20f7d8c90e2745b9d3", size = 1974982 },
+    { url = "https://files.pythonhosted.org/packages/7b/03/ab118cb743dcf671da01ad0cfd7564465dda115db32976fdc95e21ce8feb/faster_whisper-1.1.0-py3-none-any.whl", hash = "sha256:0f2d025676bbff1e46c4108b6f9a82578d6e33826c174af2990e45b33fab6182", size = 1118168 },
 ]
 
 [[package]]
@@ -295,7 +296,7 @@ requires-dist = [
     { name = "basedpyright", marker = "extra == 'dev'", specifier = ">=1.18.0" },
     { name = "ctranslate2", specifier = ">=4.5.0" },
     { name = "fastapi", specifier = ">=0.115.0" },
-    { name = "faster-whisper", specifier = ">=1.0.3" },
+    { name = "faster-whisper", specifier = ">=1.1.0" },
     { name = "gradio", marker = "extra == 'ui'", specifier = ">=5.0.2" },
     { name = "httpx", marker = "extra == 'ui'", specifier = ">=0.27.2" },
     { name = "httpx-sse", marker = "extra == 'ui'", specifier = ">=0.4.0" },