File size: 2,592 Bytes
313814b
 
3a14175
bf48682
7cc3853
8f3dcc9
313814b
d0feed8
 
 
e0e6882
313814b
ba81a8e
 
 
8f3dcc9
 
ba81a8e
7785332
 
ba81a8e
8f3dcc9
d0feed8
313814b
dc4f25f
8f3dcc9
dc4f25f
7785332
 
 
 
 
 
 
 
 
 
 
 
 
aada575
bf48682
ede9e6a
 
bf48682
 
ede9e6a
 
7cc3853
ba81a8e
7cc3853
 
 
 
 
 
bf48682
3a14175
bf48682
 
 
 
 
3a14175
dcbab06
 
307e23f
dcbab06
7785332
3a14175
bf48682
7785332
bf48682
7cc3853
 
313814b
bf48682
 
 
 
 
 
 
 
8f3dcc9
bf48682
 
9d4a9a2
ba81a8e
4e64465
bf48682
4e64465
bf48682
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from __future__ import annotations

from contextlib import asynccontextmanager
import logging
import platform
from typing import TYPE_CHECKING

from fastapi import (
    FastAPI,
)
from fastapi.middleware.cors import CORSMiddleware

from speaches.dependencies import ApiKeyDependency, get_config, get_model_manager
from speaches.logger import setup_logger
from speaches.routers.misc import (
    router as misc_router,
)
from speaches.routers.models import (
    router as models_router,
)
from speaches.routers.stt import (
    router as stt_router,
)

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator

# OpenAPI tag metadata handed to FastAPI via `openapi_tags` in `create_app`.
# Every tag except "experimental" carries only a name; "experimental" also has
# a description warning that its endpoints are unstable.
# https://swagger.io/docs/specification/v3_0/grouping-operations-with-tags/
# https://fastapi.tiangolo.com/tutorial/metadata/#metadata-for-tags
TAGS_METADATA = [
    *(
        {"name": tag_name}
        for tag_name in (
            "automatic-speech-recognition",
            "speech-to-text",
            "models",
            "diagnostic",
        )
    ),
    {
        "name": "experimental",
        "description": "Not meant for public use yet. May change or be removed at any time.",
    },
]


def create_app() -> FastAPI:
    """Build and configure the Speaches FastAPI application.

    Steps, in order:
      1. Load the config and set up logging at ``config.log_level``.
      2. Import the speech router, but only on ``x86_64`` machines.
      3. Create the app with a lifespan hook that preloads
         ``config.preload_models``.
      4. Register the speech-to-text, models and misc routers, plus the speech
         router when it was imported.
      5. Optionally add CORS middleware (when ``config.allow_origins`` is set).
      6. Optionally mount the Gradio UI at ``/`` (when ``config.enable_ui`` is
         true).

    Returns:
        The configured application. When the UI is enabled this is the value
        returned by ``gr.mount_gradio_app``.
    """
    config = get_config()  # HACK
    setup_logger(config.log_level)
    logger = logging.getLogger(__name__)

    # Lazy %-style arguments: the config is only rendered to text when DEBUG
    # logging is actually enabled (output is identical to the former f-string).
    # NOTE(review): confirm the config's str() does not reveal `api_key`.
    logger.debug("Config: %s", config)

    # The `/v1/audio/speech` router is imported lazily so non-x86_64 machines
    # never import its module.
    # NOTE(review): `platform.machine()` returns "AMD64" on x86_64 Windows, so
    # this exact comparison also disables the router there — confirm intent.
    if platform.machine() == "x86_64":
        from speaches.routers.speech import (
            router as speech_router,
        )
    else:
        logger.warning("`/v1/audio/speech` is only supported on x86_64 machines")
        speech_router = None

    model_manager = get_model_manager()  # HACK

    @asynccontextmanager
    async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
        """Load each model listed in ``config.preload_models`` at startup."""
        for model_name in config.preload_models:
            model_manager.load_model(model_name)
        yield

    # App-wide dependencies: FastAPI applies these to every route. The API-key
    # dependency is only installed when an API key is configured.
    dependencies = []
    if config.api_key is not None:
        dependencies.append(ApiKeyDependency)

    app = FastAPI(lifespan=lifespan, dependencies=dependencies, openapi_tags=TAGS_METADATA)

    app.include_router(stt_router)
    app.include_router(models_router)
    app.include_router(misc_router)
    if speech_router is not None:
        app.include_router(speech_router)

    # CORS is opt-in: middleware is only added when origins are configured.
    # NOTE(review): credentials are allowed together with all methods/headers;
    # if `allow_origins` can contain "*", double-check this is acceptable.
    if config.allow_origins is not None:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=config.allow_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    if config.enable_ui:
        # Imported lazily so gradio is only loaded when the UI is enabled.
        import gradio as gr

        from speaches.gradio_app import create_gradio_demo

        app = gr.mount_gradio_app(app, create_gradio_demo(config), path="/")

    return app