Spaces:
Configuration error
Configuration error
Fedir Zadniprovskyi
commited on
Commit
·
ede9e6a
1
Parent(s):
74ecebe
tests: proper `get_config` dependency override
Browse files
src/faster_whisper_server/dependencies.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from functools import lru_cache
|
|
|
2 |
from typing import Annotated
|
3 |
|
4 |
from fastapi import Depends, HTTPException, status
|
@@ -11,7 +12,13 @@ from openai.resources.chat.completions import AsyncCompletions
|
|
11 |
from faster_whisper_server.config import Config
|
12 |
from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager
|
13 |
|
|
|
14 |
|
|
|
|
|
|
|
|
|
|
|
15 |
@lru_cache
|
16 |
def get_config() -> Config:
|
17 |
return Config()
|
@@ -22,7 +29,7 @@ ConfigDependency = Annotated[Config, Depends(get_config)]
|
|
22 |
|
23 |
@lru_cache
|
24 |
def get_model_manager() -> WhisperModelManager:
|
25 |
-
config = get_config()
|
26 |
return WhisperModelManager(config.whisper)
|
27 |
|
28 |
|
@@ -31,8 +38,8 @@ ModelManagerDependency = Annotated[WhisperModelManager, Depends(get_model_manage
|
|
31 |
|
32 |
@lru_cache
|
33 |
def get_piper_model_manager() -> PiperModelManager:
|
34 |
-
config = get_config()
|
35 |
-
return PiperModelManager(config.whisper.ttl) # HACK
|
36 |
|
37 |
|
38 |
PiperModelManagerDependency = Annotated[PiperModelManager, Depends(get_piper_model_manager)]
|
@@ -53,7 +60,7 @@ ApiKeyDependency = Depends(verify_api_key)
|
|
53 |
|
54 |
@lru_cache
|
55 |
def get_completion_client() -> AsyncCompletions:
|
56 |
-
config = get_config()
|
57 |
oai_client = AsyncOpenAI(base_url=config.chat_completion_base_url, api_key=config.chat_completion_api_key)
|
58 |
return oai_client.chat.completions
|
59 |
|
@@ -63,9 +70,9 @@ CompletionClientDependency = Annotated[AsyncCompletions, Depends(get_completion_
|
|
63 |
|
64 |
@lru_cache
|
65 |
def get_speech_client() -> AsyncSpeech:
|
66 |
-
config = get_config()
|
67 |
if config.speech_base_url is None:
|
68 |
-
# this might not work as expected if
|
69 |
from faster_whisper_server.routers.speech import (
|
70 |
router as speech_router,
|
71 |
)
|
@@ -86,7 +93,7 @@ SpeechClientDependency = Annotated[AsyncSpeech, Depends(get_speech_client)]
|
|
86 |
def get_transcription_client() -> AsyncTranscriptions:
|
87 |
config = get_config()
|
88 |
if config.transcription_base_url is None:
|
89 |
-
# this might not work as expected if
|
90 |
from faster_whisper_server.routers.stt import (
|
91 |
router as stt_router,
|
92 |
)
|
|
|
1 |
from functools import lru_cache
|
2 |
+
import logging
|
3 |
from typing import Annotated
|
4 |
|
5 |
from fastapi import Depends, HTTPException, status
|
|
|
12 |
from faster_whisper_server.config import Config
|
13 |
from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager
|
14 |
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
|
17 |
+
# NOTE: `get_config` is called directly instead of using sub-dependencies so that these functions could be used outside of `FastAPI` # noqa: E501
|
18 |
+
|
19 |
+
|
20 |
+
# https://fastapi.tiangolo.com/advanced/settings/?h=setti#creating-the-settings-only-once-with-lru_cache
|
21 |
+
# WARN: Any new module that ends up calling this function directly (not through `FastAPI` dependency injection) should be patched in `tests/conftest.py` # noqa: E501
|
22 |
@lru_cache
|
23 |
def get_config() -> Config:
|
24 |
return Config()
|
|
|
29 |
|
30 |
@lru_cache
|
31 |
def get_model_manager() -> WhisperModelManager:
|
32 |
+
config = get_config()
|
33 |
return WhisperModelManager(config.whisper)
|
34 |
|
35 |
|
|
|
38 |
|
39 |
@lru_cache
|
40 |
def get_piper_model_manager() -> PiperModelManager:
|
41 |
+
config = get_config()
|
42 |
+
return PiperModelManager(config.whisper.ttl) # HACK: should have its own config
|
43 |
|
44 |
|
45 |
PiperModelManagerDependency = Annotated[PiperModelManager, Depends(get_piper_model_manager)]
|
|
|
60 |
|
61 |
@lru_cache
|
62 |
def get_completion_client() -> AsyncCompletions:
|
63 |
+
config = get_config()
|
64 |
oai_client = AsyncOpenAI(base_url=config.chat_completion_base_url, api_key=config.chat_completion_api_key)
|
65 |
return oai_client.chat.completions
|
66 |
|
|
|
70 |
|
71 |
@lru_cache
|
72 |
def get_speech_client() -> AsyncSpeech:
|
73 |
+
config = get_config()
|
74 |
if config.speech_base_url is None:
|
75 |
+
# this might not work as expected if `speech_router` won't have shared state (access to the same `model_manager`) with the main FastAPI `app`. TODO: verify # noqa: E501
|
76 |
from faster_whisper_server.routers.speech import (
|
77 |
router as speech_router,
|
78 |
)
|
|
|
93 |
def get_transcription_client() -> AsyncTranscriptions:
|
94 |
config = get_config()
|
95 |
if config.transcription_base_url is None:
|
96 |
+
# this might not work as expected if `transcription_router` won't have shared state (access to the same `model_manager`) with the main FastAPI `app`. TODO: verify # noqa: E501
|
97 |
from faster_whisper_server.routers.stt import (
|
98 |
router as stt_router,
|
99 |
)
|
src/faster_whisper_server/logger.py
CHANGED
@@ -1,11 +1,8 @@
|
|
1 |
import logging
|
2 |
|
3 |
-
from faster_whisper_server.dependencies import get_config
|
4 |
|
5 |
-
|
6 |
-
def setup_logger() -> None:
|
7 |
-
config = get_config() # HACK
|
8 |
logging.getLogger().setLevel(logging.INFO)
|
9 |
logger = logging.getLogger(__name__)
|
10 |
-
logger.setLevel(
|
11 |
logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(lineno)d:%(message)s")
|
|
|
1 |
import logging
|
2 |
|
|
|
3 |
|
4 |
+
def setup_logger(log_level: str) -> None:
|
|
|
|
|
5 |
logging.getLogger().setLevel(logging.INFO)
|
6 |
logger = logging.getLogger(__name__)
|
7 |
+
logger.setLevel(log_level.upper())
|
8 |
logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(lineno)d:%(message)s")
|
src/faster_whisper_server/main.py
CHANGED
@@ -27,10 +27,12 @@ if TYPE_CHECKING:
|
|
27 |
|
28 |
|
29 |
def create_app() -> FastAPI:
|
30 |
-
|
31 |
-
|
32 |
logger = logging.getLogger(__name__)
|
33 |
|
|
|
|
|
34 |
if platform.machine() == "x86_64":
|
35 |
from faster_whisper_server.routers.speech import (
|
36 |
router as speech_router,
|
@@ -39,9 +41,6 @@ def create_app() -> FastAPI:
|
|
39 |
logger.warning("`/v1/audio/speech` is only supported on x86_64 machines")
|
40 |
speech_router = None
|
41 |
|
42 |
-
config = get_config() # HACK
|
43 |
-
logger.debug(f"Config: {config}")
|
44 |
-
|
45 |
model_manager = get_model_manager() # HACK
|
46 |
|
47 |
@asynccontextmanager
|
|
|
27 |
|
28 |
|
29 |
def create_app() -> FastAPI:
|
30 |
+
config = get_config() # HACK
|
31 |
+
setup_logger(config.log_level)
|
32 |
logger = logging.getLogger(__name__)
|
33 |
|
34 |
+
logger.debug(f"Config: {config}")
|
35 |
+
|
36 |
if platform.machine() == "x86_64":
|
37 |
from faster_whisper_server.routers.speech import (
|
38 |
router as speech_router,
|
|
|
41 |
logger.warning("`/v1/audio/speech` is only supported on x86_64 machines")
|
42 |
speech_router = None
|
43 |
|
|
|
|
|
|
|
44 |
model_manager = get_model_manager() # HACK
|
45 |
|
46 |
@asynccontextmanager
|
tests/conftest.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
from collections.abc import AsyncGenerator, Generator
|
|
|
2 |
import logging
|
3 |
import os
|
|
|
4 |
|
5 |
from fastapi.testclient import TestClient
|
6 |
from httpx import ASGITransport, AsyncClient
|
@@ -8,19 +10,31 @@ from huggingface_hub import snapshot_download
|
|
8 |
from openai import AsyncOpenAI
|
9 |
import pytest
|
10 |
import pytest_asyncio
|
|
|
11 |
|
|
|
|
|
12 |
from faster_whisper_server.main import create_app
|
13 |
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
|
17 |
def pytest_configure() -> None:
|
18 |
-
for logger_name in
|
19 |
logger = logging.getLogger(logger_name)
|
20 |
logger.disabled = True
|
21 |
|
22 |
|
23 |
-
# NOTE: not being used. Keeping just in case
|
24 |
@pytest.fixture
|
25 |
def client() -> Generator[TestClient, None, None]:
|
26 |
os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
|
@@ -28,10 +42,37 @@ def client() -> Generator[TestClient, None, None]:
|
|
28 |
yield client
|
29 |
|
30 |
|
|
|
|
|
|
|
|
|
|
|
31 |
@pytest_asyncio.fixture()
|
32 |
-
async def
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
yield aclient
|
36 |
|
37 |
|
@@ -43,11 +84,13 @@ def openai_client(aclient: AsyncClient) -> AsyncOpenAI:
|
|
43 |
@pytest.fixture
|
44 |
def actual_openai_client() -> AsyncOpenAI:
|
45 |
return AsyncOpenAI(
|
46 |
-
base_url
|
47 |
-
|
|
|
48 |
|
49 |
|
50 |
# TODO: remove the download after running the tests
|
|
|
51 |
@pytest.fixture(scope="session", autouse=True)
|
52 |
def download_piper_voices() -> None:
|
53 |
# Only download `voices.json` and the default voice
|
|
|
1 |
from collections.abc import AsyncGenerator, Generator
|
2 |
+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
3 |
import logging
|
4 |
import os
|
5 |
+
from typing import Protocol
|
6 |
|
7 |
from fastapi.testclient import TestClient
|
8 |
from httpx import ASGITransport, AsyncClient
|
|
|
10 |
from openai import AsyncOpenAI
|
11 |
import pytest
|
12 |
import pytest_asyncio
|
13 |
+
from pytest_mock import MockerFixture
|
14 |
|
15 |
+
from faster_whisper_server.config import Config, WhisperConfig
|
16 |
+
from faster_whisper_server.dependencies import get_config
|
17 |
from faster_whisper_server.main import create_app
|
18 |
|
19 |
+
DISABLE_LOGGERS = ["multipart.multipart", "faster_whisper"]
|
20 |
+
OPENAI_BASE_URL = "https://api.openai.com/v1"
|
21 |
+
DEFAULT_WHISPER_MODEL = "Systran/faster-whisper-tiny.en"
|
22 |
+
# TODO: figure out a way to initialize the config without parsing environment variables, as those may interfere with the tests # noqa: E501
|
23 |
+
DEFAULT_WHISPER_CONFIG = WhisperConfig(model=DEFAULT_WHISPER_MODEL, ttl=0)
|
24 |
+
DEFAULT_CONFIG = Config(
|
25 |
+
whisper=DEFAULT_WHISPER_CONFIG,
|
26 |
+
# disable the UI as it slightly increases the app startup time due to the imports it's doing
|
27 |
+
enable_ui=False,
|
28 |
+
)
|
29 |
|
30 |
|
31 |
def pytest_configure() -> None:
|
32 |
+
for logger_name in DISABLE_LOGGERS:
|
33 |
logger = logging.getLogger(logger_name)
|
34 |
logger.disabled = True
|
35 |
|
36 |
|
37 |
+
# NOTE: not being used. Keeping just in case. Needs to be modified to work similarly to `aclient_factory`
|
38 |
@pytest.fixture
|
39 |
def client() -> Generator[TestClient, None, None]:
|
40 |
os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
|
|
|
42 |
yield client
|
43 |
|
44 |
|
45 |
+
# https://stackoverflow.com/questions/74890214/type-hint-callback-function-with-optional-parameters-aka-callable-with-optional
|
46 |
+
class AclientFactory(Protocol):
|
47 |
+
def __call__(self, config: Config = DEFAULT_CONFIG) -> AbstractAsyncContextManager[AsyncClient]: ...
|
48 |
+
|
49 |
+
|
50 |
@pytest_asyncio.fixture()
|
51 |
+
async def aclient_factory(mocker: MockerFixture) -> AclientFactory:
|
52 |
+
"""Returns a context manager that provides an `AsyncClient` instance with `app` using the provided configuration."""
|
53 |
+
|
54 |
+
@asynccontextmanager
|
55 |
+
async def inner(config: Config = DEFAULT_CONFIG) -> AsyncGenerator[AsyncClient, None]:
|
56 |
+
# NOTE: all calls to `get_config` should be patched. One way to test that this works is to update the original `get_config` to raise an exception and see if the tests fail # noqa: E501
|
57 |
+
mocker.patch("faster_whisper_server.dependencies.get_config", return_value=config)
|
58 |
+
mocker.patch("faster_whisper_server.main.get_config", return_value=config)
|
59 |
+
# NOTE: I couldn't get the following to work but it shouldn't matter
|
60 |
+
# mocker.patch(
|
61 |
+
# "faster_whisper_server.text_utils.Transcription._ensure_no_word_overlap.get_config", return_value=config
|
62 |
+
# )
|
63 |
+
|
64 |
+
app = create_app()
|
65 |
+
# https://fastapi.tiangolo.com/advanced/testing-dependencies/
|
66 |
+
app.dependency_overrides[get_config] = lambda: config
|
67 |
+
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as aclient:
|
68 |
+
yield aclient
|
69 |
+
|
70 |
+
return inner
|
71 |
+
|
72 |
+
|
73 |
+
@pytest_asyncio.fixture()
|
74 |
+
async def aclient(aclient_factory: AclientFactory) -> AsyncGenerator[AsyncClient, None]:
|
75 |
+
async with aclient_factory() as aclient:
|
76 |
yield aclient
|
77 |
|
78 |
|
|
|
84 |
@pytest.fixture
|
85 |
def actual_openai_client() -> AsyncOpenAI:
|
86 |
return AsyncOpenAI(
|
87 |
+
# `base_url` is provided in case `OPENAI_BASE_URL` is set to a different value
|
88 |
+
base_url=OPENAI_BASE_URL
|
89 |
+
)
|
90 |
|
91 |
|
92 |
# TODO: remove the download after running the tests
|
93 |
+
# TODO: do not download when not needed
|
94 |
@pytest.fixture(scope="session", autouse=True)
|
95 |
def download_piper_voices() -> None:
|
96 |
# Only download `voices.json` and the default voice
|
tests/model_manager_test.py
CHANGED
@@ -1,23 +1,22 @@
|
|
1 |
import asyncio
|
2 |
-
import os
|
3 |
|
4 |
import anyio
|
5 |
-
from httpx import ASGITransport, AsyncClient
|
6 |
import pytest
|
7 |
|
8 |
-
from faster_whisper_server.
|
|
|
|
|
|
|
9 |
|
10 |
|
11 |
@pytest.mark.asyncio
|
12 |
-
async def test_model_unloaded_after_ttl() -> None:
|
13 |
ttl = 5
|
14 |
-
model =
|
15 |
-
|
16 |
-
os.environ["ENABLE_UI"] = "false"
|
17 |
-
async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
|
18 |
res = (await aclient.get("/api/ps")).json()
|
19 |
assert len(res["models"]) == 0
|
20 |
-
await aclient.post(f"/api/ps/{
|
21 |
res = (await aclient.get("/api/ps")).json()
|
22 |
assert len(res["models"]) == 1
|
23 |
await asyncio.sleep(ttl + 1) # wait for the model to be unloaded
|
@@ -26,13 +25,11 @@ async def test_model_unloaded_after_ttl() -> None:
|
|
26 |
|
27 |
|
28 |
@pytest.mark.asyncio
|
29 |
-
async def test_ttl_resets_after_usage() -> None:
|
30 |
ttl = 5
|
31 |
-
model =
|
32 |
-
|
33 |
-
|
34 |
-
async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
|
35 |
-
await aclient.post(f"/api/ps/{model}")
|
36 |
res = (await aclient.get("/api/ps")).json()
|
37 |
assert len(res["models"]) == 1
|
38 |
await asyncio.sleep(ttl - 2) # sleep for less than the ttl. The model should not be unloaded
|
@@ -43,7 +40,9 @@ async def test_ttl_resets_after_usage() -> None:
|
|
43 |
data = await f.read()
|
44 |
res = (
|
45 |
await aclient.post(
|
46 |
-
"/v1/audio/transcriptions",
|
|
|
|
|
47 |
)
|
48 |
).json()
|
49 |
res = (await aclient.get("/api/ps")).json()
|
@@ -60,28 +59,28 @@ async def test_ttl_resets_after_usage() -> None:
|
|
60 |
# this just ensures the model can be loaded again after being unloaded
|
61 |
res = (
|
62 |
await aclient.post(
|
63 |
-
"/v1/audio/transcriptions",
|
|
|
|
|
64 |
)
|
65 |
).json()
|
66 |
|
67 |
|
68 |
@pytest.mark.asyncio
|
69 |
-
async def test_model_cant_be_unloaded_when_used() -> None:
|
70 |
ttl = 0
|
71 |
-
model =
|
72 |
-
|
73 |
-
os.environ["ENABLE_UI"] = "false"
|
74 |
-
async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
|
75 |
async with await anyio.open_file("audio.wav", "rb") as f:
|
76 |
data = await f.read()
|
77 |
|
78 |
task = asyncio.create_task(
|
79 |
aclient.post(
|
80 |
-
"/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model":
|
81 |
)
|
82 |
)
|
83 |
await asyncio.sleep(0.1) # wait for the server to start processing the request
|
84 |
-
res = await aclient.delete(f"/api/ps/{
|
85 |
assert res.status_code == 409
|
86 |
|
87 |
await task
|
@@ -90,27 +89,23 @@ async def test_model_cant_be_unloaded_when_used() -> None:
|
|
90 |
|
91 |
|
92 |
@pytest.mark.asyncio
|
93 |
-
async def test_model_cant_be_loaded_twice() -> None:
|
94 |
ttl = -1
|
95 |
-
model =
|
96 |
-
|
97 |
-
|
98 |
-
async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
|
99 |
-
res = await aclient.post(f"/api/ps/{model}")
|
100 |
assert res.status_code == 201
|
101 |
-
res = await aclient.post(f"/api/ps/{
|
102 |
assert res.status_code == 409
|
103 |
res = (await aclient.get("/api/ps")).json()
|
104 |
assert len(res["models"]) == 1
|
105 |
|
106 |
|
107 |
@pytest.mark.asyncio
|
108 |
-
async def test_model_is_unloaded_after_request_when_ttl_is_zero() -> None:
|
109 |
ttl = 0
|
110 |
-
|
111 |
-
|
112 |
-
os.environ["ENABLE_UI"] = "false"
|
113 |
-
async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
|
114 |
async with await anyio.open_file("audio.wav", "rb") as f:
|
115 |
data = await f.read()
|
116 |
res = await aclient.post(
|
|
|
1 |
import asyncio
|
|
|
2 |
|
3 |
import anyio
|
|
|
4 |
import pytest
|
5 |
|
6 |
+
from faster_whisper_server.config import Config, WhisperConfig
|
7 |
+
from tests.conftest import DEFAULT_WHISPER_MODEL, AclientFactory
|
8 |
+
|
9 |
+
MODEL = DEFAULT_WHISPER_MODEL # just to make the test more readable
|
10 |
|
11 |
|
12 |
@pytest.mark.asyncio
|
13 |
+
async def test_model_unloaded_after_ttl(aclient_factory: AclientFactory) -> None:
|
14 |
ttl = 5
|
15 |
+
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
|
16 |
+
async with aclient_factory(config) as aclient:
|
|
|
|
|
17 |
res = (await aclient.get("/api/ps")).json()
|
18 |
assert len(res["models"]) == 0
|
19 |
+
await aclient.post(f"/api/ps/{MODEL}")
|
20 |
res = (await aclient.get("/api/ps")).json()
|
21 |
assert len(res["models"]) == 1
|
22 |
await asyncio.sleep(ttl + 1) # wait for the model to be unloaded
|
|
|
25 |
|
26 |
|
27 |
@pytest.mark.asyncio
|
28 |
+
async def test_ttl_resets_after_usage(aclient_factory: AclientFactory) -> None:
|
29 |
ttl = 5
|
30 |
+
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
|
31 |
+
async with aclient_factory(config) as aclient:
|
32 |
+
await aclient.post(f"/api/ps/{MODEL}")
|
|
|
|
|
33 |
res = (await aclient.get("/api/ps")).json()
|
34 |
assert len(res["models"]) == 1
|
35 |
await asyncio.sleep(ttl - 2) # sleep for less than the ttl. The model should not be unloaded
|
|
|
40 |
data = await f.read()
|
41 |
res = (
|
42 |
await aclient.post(
|
43 |
+
"/v1/audio/transcriptions",
|
44 |
+
files={"file": ("audio.wav", data, "audio/wav")},
|
45 |
+
data={"model": MODEL},
|
46 |
)
|
47 |
).json()
|
48 |
res = (await aclient.get("/api/ps")).json()
|
|
|
59 |
# this just ensures the model can be loaded again after being unloaded
|
60 |
res = (
|
61 |
await aclient.post(
|
62 |
+
"/v1/audio/transcriptions",
|
63 |
+
files={"file": ("audio.wav", data, "audio/wav")},
|
64 |
+
data={"model": MODEL},
|
65 |
)
|
66 |
).json()
|
67 |
|
68 |
|
69 |
@pytest.mark.asyncio
|
70 |
+
async def test_model_cant_be_unloaded_when_used(aclient_factory: AclientFactory) -> None:
|
71 |
ttl = 0
|
72 |
+
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
|
73 |
+
async with aclient_factory(config) as aclient:
|
|
|
|
|
74 |
async with await anyio.open_file("audio.wav", "rb") as f:
|
75 |
data = await f.read()
|
76 |
|
77 |
task = asyncio.create_task(
|
78 |
aclient.post(
|
79 |
+
"/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": MODEL}
|
80 |
)
|
81 |
)
|
82 |
await asyncio.sleep(0.1) # wait for the server to start processing the request
|
83 |
+
res = await aclient.delete(f"/api/ps/{MODEL}")
|
84 |
assert res.status_code == 409
|
85 |
|
86 |
await task
|
|
|
89 |
|
90 |
|
91 |
@pytest.mark.asyncio
|
92 |
+
async def test_model_cant_be_loaded_twice(aclient_factory: AclientFactory) -> None:
|
93 |
ttl = -1
|
94 |
+
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
|
95 |
+
async with aclient_factory(config) as aclient:
|
96 |
+
res = await aclient.post(f"/api/ps/{MODEL}")
|
|
|
|
|
97 |
assert res.status_code == 201
|
98 |
+
res = await aclient.post(f"/api/ps/{MODEL}")
|
99 |
assert res.status_code == 409
|
100 |
res = (await aclient.get("/api/ps")).json()
|
101 |
assert len(res["models"]) == 1
|
102 |
|
103 |
|
104 |
@pytest.mark.asyncio
|
105 |
+
async def test_model_is_unloaded_after_request_when_ttl_is_zero(aclient_factory: AclientFactory) -> None:
|
106 |
ttl = 0
|
107 |
+
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
|
108 |
+
async with aclient_factory(config) as aclient:
|
|
|
|
|
109 |
async with await anyio.open_file("audio.wav", "rb") as f:
|
110 |
data = await f.read()
|
111 |
res = await aclient.post(
|