File size: 3,582 Bytes
bf48682
 
 
dcbab06
 
7f77e4b
 
 
 
bf48682
 
7cc3853
bf48682
 
 
 
 
 
 
 
 
 
 
7cc3853
bf48682
7cc3853
bf48682
 
7cc3853
 
 
 
 
 
 
 
 
 
dcbab06
 
 
 
 
 
 
 
 
 
 
 
 
7f77e4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import secrets
from functools import lru_cache
from typing import Annotated

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from httpx import ASGITransport, AsyncClient
from openai import AsyncOpenAI
from openai.resources.audio import AsyncSpeech, AsyncTranscriptions
from openai.resources.chat.completions import AsyncCompletions

from faster_whisper_server.config import Config
from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager


@lru_cache
def get_config() -> Config:
    """Construct the application ``Config`` once and hand back the same object thereafter.

    Because the function takes no arguments, ``lru_cache`` memoizes exactly one
    result, so every caller shares a single ``Config`` instance.
    """
    config = Config()
    return config


# FastAPI dependency alias: annotating a route parameter with this injects the cached config.
ConfigDependency = Annotated[Config, Depends(get_config)]


@lru_cache
def get_model_manager() -> WhisperModelManager:
    """Return the single, cached ``WhisperModelManager`` for this process.

    The manager is built from the ``whisper`` section of the shared config.
    """
    # HACK: the config is fetched by calling get_config() directly instead of being
    # injected through ConfigDependency (presumably so lru_cache can memoize a
    # zero-argument call -- TODO confirm).
    whisper_settings = get_config().whisper
    return WhisperModelManager(whisper_settings)


# FastAPI dependency alias injecting the cached WhisperModelManager.
ModelManagerDependency = Annotated[WhisperModelManager, Depends(get_model_manager)]


@lru_cache
def get_piper_model_manager() -> PiperModelManager:
    """Return the single, cached ``PiperModelManager`` for this process.

    NOTE: the TTL passed to the manager is read from ``config.whisper.ttl``;
    no Piper-specific setting is consulted here.
    """
    # HACK: both the direct get_config() call and the reuse of the whisper TTL
    # for Piper are workarounds flagged in the original code.
    borrowed_ttl = get_config().whisper.ttl
    return PiperModelManager(borrowed_ttl)


# FastAPI dependency alias injecting the cached PiperModelManager.
PiperModelManagerDependency = Annotated[PiperModelManager, Depends(get_piper_model_manager)]


security = HTTPBearer()


async def verify_api_key(
    config: ConfigDependency, credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)]
) -> None:
    """Reject the request unless its bearer token equals ``config.api_key``.

    The token is extracted from the ``Authorization`` header by the
    ``HTTPBearer`` dependency (``security``) before this function runs.

    Raises:
        HTTPException: 403 Forbidden when the token does not match, or when
            ``config.api_key`` is ``None`` (no key can ever match an unset key,
            which mirrors the previous ``!=`` comparison).
    """
    expected_key = config.api_key
    # NOTE(review): assumes config.api_key is ``str | None`` -- confirm against Config.
    if expected_key is None:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
    # Constant-time comparison: a plain ``!=`` can return early on the first
    # differing character, letting response timing leak how much of the key a
    # guess got right. Compare bytes because compare_digest rejects non-ASCII str.
    if not secrets.compare_digest(credentials.credentials.encode("utf-8"), expected_key.encode("utf-8")):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)


# Route-level dependency (e.g. ``dependencies=[ApiKeyDependency]``) that enforces the API key.
ApiKeyDependency = Depends(verify_api_key)


@lru_cache
def get_completion_client() -> AsyncCompletions:
    """Return a cached OpenAI-compatible chat-completions resource.

    The underlying ``AsyncOpenAI`` client targets ``config.chat_completion_base_url``
    and authenticates with ``config.chat_completion_api_key``.
    """
    settings = get_config()  # HACK
    client = AsyncOpenAI(
        base_url=settings.chat_completion_base_url,
        api_key=settings.chat_completion_api_key,
    )
    completions = client.chat.completions
    return completions


# FastAPI dependency alias injecting the cached chat-completions resource.
CompletionClientDependency = Annotated[AsyncCompletions, Depends(get_completion_client)]


@lru_cache
def get_speech_client() -> AsyncSpeech:
    """Return a cached OpenAI ``audio.speech`` resource.

    If ``config.speech_base_url`` is set, requests go to that URL. Otherwise an
    httpx ``AsyncClient`` with an ``ASGITransport`` dispatches them straight to
    this package's speech router, in-process.
    """
    settings = get_config()  # HACK
    if settings.speech_base_url is not None:
        remote = AsyncOpenAI(base_url=settings.speech_base_url, api_key=settings.speech_api_key)
        return remote.audio.speech

    # This might not work as expected if the `speech_router` won't have shared
    # state with the main FastAPI `app`. TODO: verify
    from faster_whisper_server.routers.speech import router as speech_router

    # NOTE: "test" is only a placeholder host; any other value works.
    in_process_http = AsyncClient(transport=ASGITransport(speech_router), base_url="http://test/v1")
    local = AsyncOpenAI(http_client=in_process_http, api_key=settings.speech_api_key)
    return local.audio.speech


# FastAPI dependency alias injecting the cached speech resource.
SpeechClientDependency = Annotated[AsyncSpeech, Depends(get_speech_client)]


@lru_cache
def get_transcription_client() -> AsyncTranscriptions:
    """Return a cached OpenAI ``audio.transcriptions`` resource.

    If ``config.transcription_base_url`` is set, requests go to that URL.
    Otherwise an httpx ``AsyncClient`` with an ``ASGITransport`` dispatches them
    straight to this package's STT router, in-process.
    """
    settings = get_config()  # HACK
    if settings.transcription_base_url is not None:
        remote = AsyncOpenAI(base_url=settings.transcription_base_url, api_key=settings.transcription_api_key)
        return remote.audio.transcriptions

    # This might not work as expected if the `transcription_router` won't have
    # shared state with the main FastAPI `app`. TODO: verify
    from faster_whisper_server.routers.stt import router as stt_router

    # NOTE: "test" is only a placeholder host; any other value works.
    in_process_http = AsyncClient(transport=ASGITransport(stt_router), base_url="http://test/v1")
    local = AsyncOpenAI(http_client=in_process_http, api_key=settings.transcription_api_key)
    return local.audio.transcriptions


# FastAPI dependency alias injecting the cached transcriptions resource.
TranscriptionClientDependency = Annotated[AsyncTranscriptions, Depends(get_transcription_client)]