from __future__ import annotations import asyncio import time from contextlib import asynccontextmanager from io import BytesIO from typing import Annotated, Literal, OrderedDict from fastapi import (FastAPI, Form, Query, Response, UploadFile, WebSocket, WebSocketDisconnect) from fastapi.responses import StreamingResponse from fastapi.websockets import WebSocketState from faster_whisper import WhisperModel from faster_whisper.vad import VadOptions, get_speech_timestamps from speaches import utils from speaches.asr import FasterWhisperASR from speaches.audio import AudioStream, audio_samples_from_file from speaches.config import (SAMPLES_PER_SECOND, Language, Model, ResponseFormat, config) from speaches.logger import logger from speaches.server_models import (TranscriptionJsonResponse, TranscriptionVerboseJsonResponse) from speaches.transcriber import audio_transcriber models: OrderedDict[Model, WhisperModel] = OrderedDict() def load_model(model_name: Model) -> WhisperModel: if model_name in models: logger.debug(f"{model_name} model already loaded") return models[model_name] if len(models) >= config.max_models: oldest_model_name = next(iter(models)) logger.info( f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}" ) del models[oldest_model_name] logger.debug(f"Loading {model_name}") start = time.perf_counter() whisper = WhisperModel( model_name, device=config.whisper.inference_device, compute_type=config.whisper.compute_type, ) logger.info( f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds" ) models[model_name] = whisper return whisper @asynccontextmanager async def lifespan(_: FastAPI): load_model(config.whisper.model) yield for model in models.keys(): logger.info(f"Unloading {model}") del models[model] app = FastAPI(lifespan=lifespan) @app.get("/health") def health() -> Response: return Response(status_code=200, content="Everything is peachy!") @app.post("/v1/audio/translations") def translate_file( file: Annotated[UploadFile, Form()], model: Annotated[Model, Form()] = config.whisper.model, prompt: Annotated[str | None, Form()] = None, response_format: Annotated[ResponseFormat, Form()] = config.default_response_format, temperature: Annotated[float, Form()] = 0.0, stream: Annotated[bool, Form()] = False, ): start = time.perf_counter() whisper = load_model(model) segments, transcription_info = whisper.transcribe( file.file, task="translate", initial_prompt=prompt, temperature=temperature, vad_filter=True, ) def segment_responses(): for segment in segments: if response_format == ResponseFormat.TEXT: yield segment.text elif response_format == ResponseFormat.JSON: yield TranscriptionJsonResponse.from_segments( [segment] ).model_dump_json() elif response_format == ResponseFormat.VERBOSE_JSON: yield TranscriptionVerboseJsonResponse.from_segment( segment, transcription_info ).model_dump_json() if not stream: segments = list(segments) logger.info( f"Translated {transcription_info.duration}({transcription_info.duration_after_vad}) seconds of audio in {time.perf_counter() - start:.2f} seconds" ) if response_format == ResponseFormat.TEXT: return utils.segments_text(segments) elif response_format == ResponseFormat.JSON: return TranscriptionJsonResponse.from_segments(segments) elif response_format == ResponseFormat.VERBOSE_JSON: return TranscriptionVerboseJsonResponse.from_segments( segments, transcription_info ) else: return StreamingResponse(segment_responses(), media_type="text/event-stream") # https://platform.openai.com/docs/api-reference/audio/createTranscription # https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915 @app.post("/v1/audio/transcriptions") def transcribe_file( file: Annotated[UploadFile, Form()], model: Annotated[Model, Form()] = config.whisper.model, language: Annotated[Language | None, Form()] = config.default_language, prompt: Annotated[str | None, Form()] = None, response_format: Annotated[ResponseFormat, Form()] = config.default_response_format, temperature: Annotated[float, Form()] = 0.0, timestamp_granularities: Annotated[ list[Literal["segments"] | Literal["words"]], Form(alias="timestamp_granularities[]"), ] = ["segments"], stream: Annotated[bool, Form()] = False, ): start = time.perf_counter() whisper = load_model(model) segments, transcription_info = whisper.transcribe( file.file, task="transcribe", language=language, initial_prompt=prompt, word_timestamps="words" in timestamp_granularities, temperature=temperature, vad_filter=True, ) def segment_responses(): for segment in segments: logger.info( f"Transcribed {segment.end - segment.start} seconds of audio in {time.perf_counter() - start:.2f} seconds" ) if response_format == ResponseFormat.TEXT: yield segment.text elif response_format == ResponseFormat.JSON: yield TranscriptionJsonResponse.from_segments( [segment] ).model_dump_json() elif response_format == ResponseFormat.VERBOSE_JSON: yield TranscriptionVerboseJsonResponse.from_segment( segment, transcription_info ).model_dump_json() if not stream: segments = list(segments) logger.info( f"Transcribed {transcription_info.duration}({transcription_info.duration_after_vad}) seconds of audio in {time.perf_counter() - start:.2f} seconds" ) if response_format == ResponseFormat.TEXT: return utils.segments_text(segments) elif response_format == ResponseFormat.JSON: return TranscriptionJsonResponse.from_segments(segments) elif response_format == ResponseFormat.VERBOSE_JSON: return TranscriptionVerboseJsonResponse.from_segments( segments, transcription_info ) else: return StreamingResponse(segment_responses(), media_type="text/event-stream") async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None: try: while True: bytes_ = await asyncio.wait_for( ws.receive_bytes(), timeout=config.max_no_data_seconds ) logger.debug(f"Received {len(bytes_)} bytes of audio data") audio_samples = audio_samples_from_file(BytesIO(bytes_)) audio_stream.extend(audio_samples) if audio_stream.duration - config.inactivity_window_seconds >= 0: audio = audio_stream.after( audio_stream.duration - config.inactivity_window_seconds ) vad_opts = VadOptions(min_silence_duration_ms=500, speech_pad_ms=0) # NOTE: This is a synchronous operation that runs every time new data is received. # This shouldn't be an issue unless data is being received in tiny chunks or the user's machine is a potato. timestamps = get_speech_timestamps(audio.data, vad_opts) if len(timestamps) == 0: logger.info( f"No speech detected in the last {config.inactivity_window_seconds} seconds." ) break elif ( # last speech end time config.inactivity_window_seconds - timestamps[-1]["end"] / SAMPLES_PER_SECOND >= config.max_inactivity_seconds ): logger.info( f"Not enough speech in the last {config.inactivity_window_seconds} seconds." ) break except asyncio.TimeoutError: logger.info( f"No data received in {config.max_no_data_seconds} seconds. Closing the connection." ) except WebSocketDisconnect as e: logger.info(f"Client disconnected: {e}") audio_stream.close() @app.websocket("/v1/audio/transcriptions") async def transcribe_stream( ws: WebSocket, model: Annotated[Model, Query()] = config.whisper.model, language: Annotated[Language | None, Query()] = config.default_language, prompt: Annotated[str | None, Query()] = None, response_format: Annotated[ ResponseFormat, Query() ] = config.default_response_format, temperature: Annotated[float, Query()] = 0.0, ) -> None: await ws.accept() transcribe_opts = { "language": language, "initial_prompt": prompt, "temperature": temperature, "vad_filter": True, "condition_on_previous_text": False, } whisper = load_model(model) asr = FasterWhisperASR(whisper, **transcribe_opts) audio_stream = AudioStream() async with asyncio.TaskGroup() as tg: tg.create_task(audio_receiver(ws, audio_stream)) async for transcription in audio_transcriber(asr, audio_stream): logger.debug(f"Sending transcription: {transcription.text}") if ws.client_state == WebSocketState.DISCONNECTED: break if response_format == ResponseFormat.TEXT: await ws.send_text(transcription.text) elif response_format == ResponseFormat.JSON: await ws.send_json( TranscriptionJsonResponse.from_transcription( transcription ).model_dump() ) elif response_format == ResponseFormat.VERBOSE_JSON: await ws.send_json( TranscriptionVerboseJsonResponse.from_transcription( transcription ).model_dump() ) if not ws.client_state == WebSocketState.DISCONNECTED: logger.info("Closing the connection.") await ws.close()