# Author: Fedir Zadniprovskyi
# Commit: aada575 — "feat: support loading multiple models"
# (raw · history · blame · 10.6 kB — residue from the web page this file was copied from)
from __future__ import annotations
import asyncio
import time
from contextlib import asynccontextmanager
from io import BytesIO
from typing import Annotated, Literal, OrderedDict
from fastapi import (FastAPI, Form, Query, Response, UploadFile, WebSocket,
WebSocketDisconnect)
from fastapi.responses import StreamingResponse
from fastapi.websockets import WebSocketState
from faster_whisper import WhisperModel
from faster_whisper.vad import VadOptions, get_speech_timestamps
from speaches import utils
from speaches.asr import FasterWhisperASR
from speaches.audio import AudioStream, audio_samples_from_file
from speaches.config import (SAMPLES_PER_SECOND, Language, Model,
ResponseFormat, config)
from speaches.logger import logger
from speaches.server_models import (TranscriptionJsonResponse,
TranscriptionVerboseJsonResponse)
from speaches.transcriber import audio_transcriber
# Cache of loaded Whisper models keyed by model name. Iteration order is
# least-recently-used first, so the first key is the eviction candidate once
# `config.max_models` is reached.
models: OrderedDict[Model, WhisperModel] = OrderedDict()


def load_model(model_name: Model) -> WhisperModel:
    """Return the Whisper model named `model_name`, loading it if necessary.

    Loaded models are cached in `models`. A cache hit marks the model as the
    most recently used one. Before a new model is loaded, least-recently-used
    models are removed from the cache until fewer than `config.max_models`
    remain. New models are built with the device and compute type taken from
    `config.whisper`.
    """
    if model_name in models:
        logger.debug(f"{model_name} model already loaded")
        # Refresh the entry's position so a frequently used model is not
        # evicted ahead of one that was merely loaded later.
        models.move_to_end(model_name)
        return models[model_name]
    # `models` guard: with `max_models <= 0` the cache may already be empty,
    # and `next(iter(models))` would raise StopIteration. A `while` also keeps
    # the cache bounded if it ever exceeded the limit.
    while models and len(models) >= config.max_models:
        oldest_model_name = next(iter(models))
        logger.info(
            f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
        )
        del models[oldest_model_name]
    logger.debug(f"Loading {model_name}")
    start = time.perf_counter()
    whisper = WhisperModel(
        model_name,
        device=config.whisper.inference_device,
        compute_type=config.whisper.compute_type,
    )
    logger.info(f"Loaded {model_name} in {time.perf_counter() - start:.2f} seconds")
    models[model_name] = whisper
    return whisper
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Application lifespan hook.

    On startup, loads the default model (`config.whisper.model`) into the
    `models` cache. On shutdown, removes every cached model from `models`.
    """
    load_model(config.whisper.model)
    yield
    # Iterate over a snapshot of the keys: deleting entries while iterating
    # the live `keys()` view raises
    # "RuntimeError: dictionary changed size during iteration".
    for model in list(models):
        logger.info(f"Unloading {model}")
        del models[model]


app = FastAPI(lifespan=lifespan)
@app.get("/health")
def health() -> Response:
    """Health-check endpoint: always answer HTTP 200 with a fixed body."""
    ok_response = Response(content="Everything is peachy!", status_code=200)
    return ok_response
@app.post("/v1/audio/translations")
def translate_file(
file: Annotated[UploadFile, Form()],
model: Annotated[Model, Form()] = config.whisper.model,
prompt: Annotated[str | None, Form()] = None,
response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
temperature: Annotated[float, Form()] = 0.0,
stream: Annotated[bool, Form()] = False,
):
start = time.perf_counter()
whisper = load_model(model)
segments, transcription_info = whisper.transcribe(
file.file,
task="translate",
initial_prompt=prompt,
temperature=temperature,
vad_filter=True,
)
def segment_responses():
for segment in segments:
if response_format == ResponseFormat.TEXT:
yield segment.text
elif response_format == ResponseFormat.JSON:
yield TranscriptionJsonResponse.from_segments(
[segment]
).model_dump_json()
elif response_format == ResponseFormat.VERBOSE_JSON:
yield TranscriptionVerboseJsonResponse.from_segment(
segment, transcription_info
).model_dump_json()
if not stream:
segments = list(segments)
logger.info(
f"Translated {transcription_info.duration}({transcription_info.duration_after_vad}) seconds of audio in {time.perf_counter() - start:.2f} seconds"
)
if response_format == ResponseFormat.TEXT:
return utils.segments_text(segments)
elif response_format == ResponseFormat.JSON:
return TranscriptionJsonResponse.from_segments(segments)
elif response_format == ResponseFormat.VERBOSE_JSON:
return TranscriptionVerboseJsonResponse.from_segments(
segments, transcription_info
)
else:
return StreamingResponse(segment_responses(), media_type="text/event-stream")
# https://platform.openai.com/docs/api-reference/audio/createTranscription
# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
@app.post("/v1/audio/transcriptions")
def transcribe_file(
file: Annotated[UploadFile, Form()],
model: Annotated[Model, Form()] = config.whisper.model,
language: Annotated[Language | None, Form()] = config.default_language,
prompt: Annotated[str | None, Form()] = None,
response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
temperature: Annotated[float, Form()] = 0.0,
timestamp_granularities: Annotated[
list[Literal["segments"] | Literal["words"]],
Form(alias="timestamp_granularities[]"),
] = ["segments"],
stream: Annotated[bool, Form()] = False,
):
start = time.perf_counter()
whisper = load_model(model)
segments, transcription_info = whisper.transcribe(
file.file,
task="transcribe",
language=language,
initial_prompt=prompt,
word_timestamps="words" in timestamp_granularities,
temperature=temperature,
vad_filter=True,
)
def segment_responses():
for segment in segments:
logger.info(
f"Transcribed {segment.end - segment.start} seconds of audio in {time.perf_counter() - start:.2f} seconds"
)
if response_format == ResponseFormat.TEXT:
yield segment.text
elif response_format == ResponseFormat.JSON:
yield TranscriptionJsonResponse.from_segments(
[segment]
).model_dump_json()
elif response_format == ResponseFormat.VERBOSE_JSON:
yield TranscriptionVerboseJsonResponse.from_segment(
segment, transcription_info
).model_dump_json()
if not stream:
segments = list(segments)
logger.info(
f"Transcribed {transcription_info.duration}({transcription_info.duration_after_vad}) seconds of audio in {time.perf_counter() - start:.2f} seconds"
)
if response_format == ResponseFormat.TEXT:
return utils.segments_text(segments)
elif response_format == ResponseFormat.JSON:
return TranscriptionJsonResponse.from_segments(segments)
elif response_format == ResponseFormat.VERBOSE_JSON:
return TranscriptionVerboseJsonResponse.from_segments(
segments, transcription_info
)
else:
return StreamingResponse(segment_responses(), media_type="text/event-stream")
async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
    """Append audio received over ``ws`` to ``audio_stream`` until input should stop.

    Each binary message is decoded with ``audio_samples_from_file`` and added to
    ``audio_stream``. Once at least ``config.inactivity_window_seconds`` of audio
    is buffered, voice-activity detection runs over that trailing window after
    every message. Receiving stops when:

    * no speech is detected in the window, or
    * the gap between the last detected speech and the end of the window is at
      least ``config.max_inactivity_seconds``, or
    * no message arrives within ``config.max_no_data_seconds``, or
    * the client disconnects.

    ``audio_stream`` is closed on every exit path, including unexpected errors
    and task cancellation.
    """
    # The VAD settings never change during a connection; build them once
    # instead of on every received message.
    vad_opts = VadOptions(min_silence_duration_ms=500, speech_pad_ms=0)
    try:
        while True:
            bytes_ = await asyncio.wait_for(
                ws.receive_bytes(), timeout=config.max_no_data_seconds
            )
            logger.debug(f"Received {len(bytes_)} bytes of audio data")
            audio_samples = audio_samples_from_file(BytesIO(bytes_))
            audio_stream.extend(audio_samples)
            if audio_stream.duration - config.inactivity_window_seconds < 0:
                # Not enough audio buffered yet to judge inactivity.
                continue
            audio = audio_stream.after(
                audio_stream.duration - config.inactivity_window_seconds
            )
            # NOTE: This is a synchronous operation that runs every time new data is received.
            # This shouldn't be an issue unless data is being received in tiny chunks or the user's machine is a potato.
            timestamps = get_speech_timestamps(audio.data, vad_opts)
            if len(timestamps) == 0:
                logger.info(
                    f"No speech detected in the last {config.inactivity_window_seconds} seconds."
                )
                break
            # End of the last speech chunk, in seconds. Assumes "end" is a
            # sample offset relative to the start of the window — TODO confirm.
            last_speech_end = timestamps[-1]["end"] / SAMPLES_PER_SECOND
            if (
                config.inactivity_window_seconds - last_speech_end
                >= config.max_inactivity_seconds
            ):
                logger.info(
                    f"Not enough speech in the last {config.inactivity_window_seconds} seconds."
                )
                break
    except asyncio.TimeoutError:
        logger.info(
            f"No data received in {config.max_no_data_seconds} seconds. Closing the connection."
        )
    except WebSocketDisconnect as e:
        logger.info(f"Client disconnected: {e}")
    finally:
        # Always signal end-of-audio so consumers of the stream are not left
        # waiting, even if decoding or the socket raised something unexpected
        # or this task was cancelled.
        audio_stream.close()
@app.websocket("/v1/audio/transcriptions")
async def transcribe_stream(
    ws: WebSocket,
    model: Annotated[Model, Query()] = config.whisper.model,
    language: Annotated[Language | None, Query()] = config.default_language,
    prompt: Annotated[str | None, Query()] = None,
    response_format: Annotated[
        ResponseFormat, Query()
    ] = config.default_response_format,
    temperature: Annotated[float, Query()] = 0.0,
) -> None:
    """Live transcription over a WebSocket.

    Accepts the connection and loads the requested model. It then runs
    ``audio_receiver`` (which fills an ``AudioStream`` from incoming messages)
    concurrently with ``audio_transcriber``. Each transcription produced is sent
    back as text or JSON according to ``response_format``. Sending stops if the
    client disconnects. The socket is closed at the end if it is still
    connected.
    """
    await ws.accept()
    transcribe_opts = {
        "language": language,
        "initial_prompt": prompt,
        "temperature": temperature,
        "vad_filter": True,
        "condition_on_previous_text": False,
    }
    # `load_model` is synchronous and may construct a new WhisperModel, which
    # can be slow. Running it in a worker thread keeps the event loop, and
    # therefore every other connection, responsive in the meantime.
    whisper = await asyncio.to_thread(load_model, model)
    asr = FasterWhisperASR(whisper, **transcribe_opts)
    audio_stream = AudioStream()
    async with asyncio.TaskGroup() as tg:
        tg.create_task(audio_receiver(ws, audio_stream))
        async for transcription in audio_transcriber(asr, audio_stream):
            logger.debug(f"Sending transcription: {transcription.text}")
            if ws.client_state == WebSocketState.DISCONNECTED:
                break
            if response_format == ResponseFormat.TEXT:
                await ws.send_text(transcription.text)
            elif response_format == ResponseFormat.JSON:
                await ws.send_json(
                    TranscriptionJsonResponse.from_transcription(
                        transcription
                    ).model_dump()
                )
            elif response_format == ResponseFormat.VERBOSE_JSON:
                await ws.send_json(
                    TranscriptionVerboseJsonResponse.from_transcription(
                        transcription
                    ).model_dump()
                )
    if ws.client_state != WebSocketState.DISCONNECTED:
        logger.info("Closing the connection.")
        await ws.close()