Fedir Zadniprovskyi commited on
Commit
9b178fc
·
1 Parent(s): bfdb6b8

feat: return 4xx on invalid files (#164)

Browse files
src/faster_whisper_server/routers/stt.py CHANGED
@@ -5,8 +5,10 @@ from io import BytesIO
5
  import logging
6
  from typing import TYPE_CHECKING, Annotated
7
 
 
8
  from fastapi import (
9
  APIRouter,
 
10
  Form,
11
  Query,
12
  Request,
@@ -15,9 +17,13 @@ from fastapi import (
15
  WebSocket,
16
  WebSocketDisconnect,
17
  )
 
18
  from fastapi.responses import StreamingResponse
19
  from fastapi.websockets import WebSocketState
 
20
  from faster_whisper.vad import VadOptions, get_speech_timestamps
 
 
21
  from pydantic import AfterValidator, Field
22
 
23
  from faster_whisper_server.api_models import (
@@ -51,6 +57,35 @@ logger = logging.getLogger(__name__)
51
  router = APIRouter()
52
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  def segments_to_response(
55
  segments: Iterable[TranscriptionSegment],
56
  transcription_info: TranscriptionInfo,
@@ -140,7 +175,7 @@ ModelName = Annotated[
140
  def translate_file(
141
  config: ConfigDependency,
142
  model_manager: ModelManagerDependency,
143
- file: Annotated[UploadFile, Form()],
144
  model: Annotated[ModelName | None, Form()] = None,
145
  prompt: Annotated[str | None, Form()] = None,
146
  response_format: Annotated[ResponseFormat | None, Form()] = None,
@@ -154,7 +189,7 @@ def translate_file(
154
  response_format = config.default_response_format
155
  with model_manager.load_model(model) as whisper:
156
  segments, transcription_info = whisper.transcribe(
157
- file.file,
158
  task=Task.TRANSLATE,
159
  initial_prompt=prompt,
160
  temperature=temperature,
@@ -190,7 +225,7 @@ def transcribe_file(
190
  config: ConfigDependency,
191
  model_manager: ModelManagerDependency,
192
  request: Request,
193
- file: Annotated[UploadFile, Form()],
194
  model: Annotated[ModelName | None, Form()] = None,
195
  language: Annotated[Language | None, Form()] = None,
196
  prompt: Annotated[str | None, Form()] = None,
@@ -218,7 +253,7 @@ def transcribe_file(
218
  )
219
  with model_manager.load_model(model) as whisper:
220
  segments, transcription_info = whisper.transcribe(
221
- file.file,
222
  task=Task.TRANSCRIBE,
223
  language=language,
224
  initial_prompt=prompt,
 
5
  import logging
6
  from typing import TYPE_CHECKING, Annotated
7
 
8
+ import av.error
9
  from fastapi import (
10
  APIRouter,
11
+ Depends,
12
  Form,
13
  Query,
14
  Request,
 
17
  WebSocket,
18
  WebSocketDisconnect,
19
  )
20
+ from fastapi.exceptions import HTTPException
21
  from fastapi.responses import StreamingResponse
22
  from fastapi.websockets import WebSocketState
23
+ from faster_whisper.audio import decode_audio
24
  from faster_whisper.vad import VadOptions, get_speech_timestamps
25
+ from numpy import float32
26
+ from numpy.typing import NDArray
27
  from pydantic import AfterValidator, Field
28
 
29
  from faster_whisper_server.api_models import (
 
57
  router = APIRouter()
58
 
59
 
60
+ # TODO: test async vs sync performance
61
+ def audio_file_dependency(
62
+ file: Annotated[UploadFile, Form()],
63
+ ) -> NDArray[float32]:
64
+ try:
65
+ audio = decode_audio(file.file)
66
+ except av.error.InvalidDataError as e:
67
+ raise HTTPException(
68
+ status_code=415,
69
+ detail="Failed to decode audio. The provided file type is not supported.",
70
+ ) from e
71
+ except av.error.ValueError as e:
72
+ raise HTTPException(
73
+ status_code=400,
74
+ # TODO: list supported file types
75
+ detail="Failed to decode audio. The provided file is likely empty.",
76
+ ) from e
77
+ except Exception as e:
78
+ logger.exception(
79
+ "Failed to decode audio. This is likely a bug. Please create an issue at https://github.com/fedirz/faster-whisper-server/issues/new."
80
+ )
81
+ raise HTTPException(status_code=500, detail="Failed to decode audio.") from e
82
+ else:
83
+ return audio # pyright: ignore reportReturnType
84
+
85
+
86
+ AudioFileDependency = Annotated[NDArray[float32], Depends(audio_file_dependency)]
87
+
88
+
89
  def segments_to_response(
90
  segments: Iterable[TranscriptionSegment],
91
  transcription_info: TranscriptionInfo,
 
175
  def translate_file(
176
  config: ConfigDependency,
177
  model_manager: ModelManagerDependency,
178
+ audio: AudioFileDependency,
179
  model: Annotated[ModelName | None, Form()] = None,
180
  prompt: Annotated[str | None, Form()] = None,
181
  response_format: Annotated[ResponseFormat | None, Form()] = None,
 
189
  response_format = config.default_response_format
190
  with model_manager.load_model(model) as whisper:
191
  segments, transcription_info = whisper.transcribe(
192
+ audio,
193
  task=Task.TRANSLATE,
194
  initial_prompt=prompt,
195
  temperature=temperature,
 
225
  config: ConfigDependency,
226
  model_manager: ModelManagerDependency,
227
  request: Request,
228
+ audio: AudioFileDependency,
229
  model: Annotated[ModelName | None, Form()] = None,
230
  language: Annotated[Language | None, Form()] = None,
231
  prompt: Annotated[str | None, Form()] = None,
 
253
  )
254
  with model_manager.load_model(model) as whisper:
255
  segments, transcription_info = whisper.transcribe(
256
+ audio,
257
  task=Task.TRANSCRIBE,
258
  language=language,
259
  initial_prompt=prompt,