Fedir Zadniprovskyi committed on
Commit
ede9e6a
·
1 Parent(s): 74ecebe

tests: proper `get_config` dependency override

Browse files
src/faster_whisper_server/dependencies.py CHANGED
@@ -1,4 +1,5 @@
1
  from functools import lru_cache
 
2
  from typing import Annotated
3
 
4
  from fastapi import Depends, HTTPException, status
@@ -11,7 +12,13 @@ from openai.resources.chat.completions import AsyncCompletions
11
  from faster_whisper_server.config import Config
12
  from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager
13
 
 
14
 
 
 
 
 
 
15
  @lru_cache
16
  def get_config() -> Config:
17
  return Config()
@@ -22,7 +29,7 @@ ConfigDependency = Annotated[Config, Depends(get_config)]
22
 
23
  @lru_cache
24
  def get_model_manager() -> WhisperModelManager:
25
- config = get_config() # HACK
26
  return WhisperModelManager(config.whisper)
27
 
28
 
@@ -31,8 +38,8 @@ ModelManagerDependency = Annotated[WhisperModelManager, Depends(get_model_manage
31
 
32
  @lru_cache
33
  def get_piper_model_manager() -> PiperModelManager:
34
- config = get_config() # HACK
35
- return PiperModelManager(config.whisper.ttl) # HACK
36
 
37
 
38
  PiperModelManagerDependency = Annotated[PiperModelManager, Depends(get_piper_model_manager)]
@@ -53,7 +60,7 @@ ApiKeyDependency = Depends(verify_api_key)
53
 
54
  @lru_cache
55
  def get_completion_client() -> AsyncCompletions:
56
- config = get_config() # HACK
57
  oai_client = AsyncOpenAI(base_url=config.chat_completion_base_url, api_key=config.chat_completion_api_key)
58
  return oai_client.chat.completions
59
 
@@ -63,9 +70,9 @@ CompletionClientDependency = Annotated[AsyncCompletions, Depends(get_completion_
63
 
64
  @lru_cache
65
  def get_speech_client() -> AsyncSpeech:
66
- config = get_config() # HACK
67
  if config.speech_base_url is None:
68
- # this might not work as expected if the `speech_router` won't have shared state with the main FastAPI `app`. TODO: verify # noqa: E501
69
  from faster_whisper_server.routers.speech import (
70
  router as speech_router,
71
  )
@@ -86,7 +93,7 @@ SpeechClientDependency = Annotated[AsyncSpeech, Depends(get_speech_client)]
86
  def get_transcription_client() -> AsyncTranscriptions:
87
  config = get_config()
88
  if config.transcription_base_url is None:
89
- # this might not work as expected if the `transcription_router` won't have shared state with the main FastAPI `app`. TODO: verify # noqa: E501
90
  from faster_whisper_server.routers.stt import (
91
  router as stt_router,
92
  )
 
1
  from functools import lru_cache
2
+ import logging
3
  from typing import Annotated
4
 
5
  from fastapi import Depends, HTTPException, status
 
12
  from faster_whisper_server.config import Config
13
  from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager
14
 
15
+ logger = logging.getLogger(__name__)
16
 
17
+ # NOTE: `get_config` is called directly instead of using sub-dependencies so that these functions could be used outside of `FastAPI` # noqa: E501
18
+
19
+
20
+ # https://fastapi.tiangolo.com/advanced/settings/?h=setti#creating-the-settings-only-once-with-lru_cache
21
+ # WARN: Any new module that ends up calling this function directly (not through `FastAPI` dependency injection) should be patched in `tests/conftest.py` # noqa: E501
22
  @lru_cache
23
  def get_config() -> Config:
24
  return Config()
 
29
 
30
  @lru_cache
31
  def get_model_manager() -> WhisperModelManager:
32
+ config = get_config()
33
  return WhisperModelManager(config.whisper)
34
 
35
 
 
38
 
39
  @lru_cache
40
  def get_piper_model_manager() -> PiperModelManager:
41
+ config = get_config()
42
+ return PiperModelManager(config.whisper.ttl) # HACK: should have its own config
43
 
44
 
45
  PiperModelManagerDependency = Annotated[PiperModelManager, Depends(get_piper_model_manager)]
 
60
 
61
  @lru_cache
62
  def get_completion_client() -> AsyncCompletions:
63
+ config = get_config()
64
  oai_client = AsyncOpenAI(base_url=config.chat_completion_base_url, api_key=config.chat_completion_api_key)
65
  return oai_client.chat.completions
66
 
 
70
 
71
  @lru_cache
72
  def get_speech_client() -> AsyncSpeech:
73
+ config = get_config()
74
  if config.speech_base_url is None:
75
+ # this might not work as expected if `speech_router` won't have shared state (access to the same `model_manager`) with the main FastAPI `app`. TODO: verify # noqa: E501
76
  from faster_whisper_server.routers.speech import (
77
  router as speech_router,
78
  )
 
93
  def get_transcription_client() -> AsyncTranscriptions:
94
  config = get_config()
95
  if config.transcription_base_url is None:
96
+ # this might not work as expected if `transcription_router` won't have shared state (access to the same `model_manager`) with the main FastAPI `app`. TODO: verify # noqa: E501
97
  from faster_whisper_server.routers.stt import (
98
  router as stt_router,
99
  )
src/faster_whisper_server/logger.py CHANGED
@@ -1,11 +1,8 @@
1
  import logging
2
 
3
- from faster_whisper_server.dependencies import get_config
4
 
5
-
6
- def setup_logger() -> None:
7
- config = get_config() # HACK
8
  logging.getLogger().setLevel(logging.INFO)
9
  logger = logging.getLogger(__name__)
10
- logger.setLevel(config.log_level.upper())
11
  logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(lineno)d:%(message)s")
 
1
  import logging
2
 
 
3
 
4
+ def setup_logger(log_level: str) -> None:
 
 
5
  logging.getLogger().setLevel(logging.INFO)
6
  logger = logging.getLogger(__name__)
7
+ logger.setLevel(log_level.upper())
8
  logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(lineno)d:%(message)s")
src/faster_whisper_server/main.py CHANGED
@@ -27,10 +27,12 @@ if TYPE_CHECKING:
27
 
28
 
29
  def create_app() -> FastAPI:
30
- setup_logger()
31
-
32
  logger = logging.getLogger(__name__)
33
 
 
 
34
  if platform.machine() == "x86_64":
35
  from faster_whisper_server.routers.speech import (
36
  router as speech_router,
@@ -39,9 +41,6 @@ def create_app() -> FastAPI:
39
  logger.warning("`/v1/audio/speech` is only supported on x86_64 machines")
40
  speech_router = None
41
 
42
- config = get_config() # HACK
43
- logger.debug(f"Config: {config}")
44
-
45
  model_manager = get_model_manager() # HACK
46
 
47
  @asynccontextmanager
 
27
 
28
 
29
  def create_app() -> FastAPI:
30
+ config = get_config() # HACK
31
+ setup_logger(config.log_level)
32
  logger = logging.getLogger(__name__)
33
 
34
+ logger.debug(f"Config: {config}")
35
+
36
  if platform.machine() == "x86_64":
37
  from faster_whisper_server.routers.speech import (
38
  router as speech_router,
 
41
  logger.warning("`/v1/audio/speech` is only supported on x86_64 machines")
42
  speech_router = None
43
 
 
 
 
44
  model_manager = get_model_manager() # HACK
45
 
46
  @asynccontextmanager
tests/conftest.py CHANGED
@@ -1,6 +1,8 @@
1
  from collections.abc import AsyncGenerator, Generator
 
2
  import logging
3
  import os
 
4
 
5
  from fastapi.testclient import TestClient
6
  from httpx import ASGITransport, AsyncClient
@@ -8,19 +10,31 @@ from huggingface_hub import snapshot_download
8
  from openai import AsyncOpenAI
9
  import pytest
10
  import pytest_asyncio
 
11
 
 
 
12
  from faster_whisper_server.main import create_app
13
 
14
- disable_loggers = ["multipart.multipart", "faster_whisper"]
 
 
 
 
 
 
 
 
 
15
 
16
 
17
  def pytest_configure() -> None:
18
- for logger_name in disable_loggers:
19
  logger = logging.getLogger(logger_name)
20
  logger.disabled = True
21
 
22
 
23
- # NOTE: not being used. Keeping just in case
24
  @pytest.fixture
25
  def client() -> Generator[TestClient, None, None]:
26
  os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
@@ -28,10 +42,37 @@ def client() -> Generator[TestClient, None, None]:
28
  yield client
29
 
30
 
 
 
 
 
 
31
  @pytest_asyncio.fixture()
32
- async def aclient() -> AsyncGenerator[AsyncClient, None]:
33
- os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
34
- async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  yield aclient
36
 
37
 
@@ -43,11 +84,13 @@ def openai_client(aclient: AsyncClient) -> AsyncOpenAI:
43
  @pytest.fixture
44
  def actual_openai_client() -> AsyncOpenAI:
45
  return AsyncOpenAI(
46
- base_url="https://api.openai.com/v1"
47
- ) # `base_url` is provided in case `OPENAI_API_BASE_URL` is set to a different value
 
48
 
49
 
50
  # TODO: remove the download after running the tests
 
51
  @pytest.fixture(scope="session", autouse=True)
52
  def download_piper_voices() -> None:
53
  # Only download `voices.json` and the default voice
 
1
  from collections.abc import AsyncGenerator, Generator
2
+ from contextlib import AbstractAsyncContextManager, asynccontextmanager
3
  import logging
4
  import os
5
+ from typing import Protocol
6
 
7
  from fastapi.testclient import TestClient
8
  from httpx import ASGITransport, AsyncClient
 
10
  from openai import AsyncOpenAI
11
  import pytest
12
  import pytest_asyncio
13
+ from pytest_mock import MockerFixture
14
 
15
+ from faster_whisper_server.config import Config, WhisperConfig
16
+ from faster_whisper_server.dependencies import get_config
17
  from faster_whisper_server.main import create_app
18
 
19
+ DISABLE_LOGGERS = ["multipart.multipart", "faster_whisper"]
20
+ OPENAI_BASE_URL = "https://api.openai.com/v1"
21
+ DEFAULT_WHISPER_MODEL = "Systran/faster-whisper-tiny.en"
22
+ # TODO: figure out a way to initialize the config without parsing environment variables, as those may interfere with the tests # noqa: E501
23
+ DEFAULT_WHISPER_CONFIG = WhisperConfig(model=DEFAULT_WHISPER_MODEL, ttl=0)
24
+ DEFAULT_CONFIG = Config(
25
+ whisper=DEFAULT_WHISPER_CONFIG,
26
+ # disable the UI as it slightly increases the app startup time due to the imports it's doing
27
+ enable_ui=False,
28
+ )
29
 
30
 
31
  def pytest_configure() -> None:
32
+ for logger_name in DISABLE_LOGGERS:
33
  logger = logging.getLogger(logger_name)
34
  logger.disabled = True
35
 
36
 
37
+ # NOTE: not being used. Keeping just in case. Needs to be modified to work similarly to `aclient_factory`
38
  @pytest.fixture
39
  def client() -> Generator[TestClient, None, None]:
40
  os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
 
42
  yield client
43
 
44
 
45
+ # https://stackoverflow.com/questions/74890214/type-hint-callback-function-with-optional-parameters-aka-callable-with-optional
46
+ class AclientFactory(Protocol):
47
+ def __call__(self, config: Config = DEFAULT_CONFIG) -> AbstractAsyncContextManager[AsyncClient]: ...
48
+
49
+
50
  @pytest_asyncio.fixture()
51
+ async def aclient_factory(mocker: MockerFixture) -> AclientFactory:
52
+ """Returns a context manager that provides an `AsyncClient` instance with `app` using the provided configuration."""
53
+
54
+ @asynccontextmanager
55
+ async def inner(config: Config = DEFAULT_CONFIG) -> AsyncGenerator[AsyncClient, None]:
56
+ # NOTE: all calls to `get_config` should be patched. One way to test that this works is to update the original `get_config` to raise an exception and see if the tests fail # noqa: E501
57
+ mocker.patch("faster_whisper_server.dependencies.get_config", return_value=config)
58
+ mocker.patch("faster_whisper_server.main.get_config", return_value=config)
59
+ # NOTE: I couldn't get the following to work but it shouldn't matter
60
+ # mocker.patch(
61
+ # "faster_whisper_server.text_utils.Transcription._ensure_no_word_overlap.get_config", return_value=config
62
+ # )
63
+
64
+ app = create_app()
65
+ # https://fastapi.tiangolo.com/advanced/testing-dependencies/
66
+ app.dependency_overrides[get_config] = lambda: config
67
+ async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as aclient:
68
+ yield aclient
69
+
70
+ return inner
71
+
72
+
73
+ @pytest_asyncio.fixture()
74
+ async def aclient(aclient_factory: AclientFactory) -> AsyncGenerator[AsyncClient, None]:
75
+ async with aclient_factory() as aclient:
76
  yield aclient
77
 
78
 
 
84
  @pytest.fixture
85
  def actual_openai_client() -> AsyncOpenAI:
86
  return AsyncOpenAI(
87
+ # `base_url` is provided in case `OPENAI_BASE_URL` is set to a different value
88
+ base_url=OPENAI_BASE_URL
89
+ )
90
 
91
 
92
  # TODO: remove the download after running the tests
93
+ # TODO: do not download when not needed
94
  @pytest.fixture(scope="session", autouse=True)
95
  def download_piper_voices() -> None:
96
  # Only download `voices.json` and the default voice
tests/model_manager_test.py CHANGED
@@ -1,23 +1,22 @@
1
  import asyncio
2
- import os
3
 
4
  import anyio
5
- from httpx import ASGITransport, AsyncClient
6
  import pytest
7
 
8
- from faster_whisper_server.main import create_app
 
 
 
9
 
10
 
11
  @pytest.mark.asyncio
12
- async def test_model_unloaded_after_ttl() -> None:
13
  ttl = 5
14
- model = "Systran/faster-whisper-tiny.en"
15
- os.environ["WHISPER__TTL"] = str(ttl)
16
- os.environ["ENABLE_UI"] = "false"
17
- async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
18
  res = (await aclient.get("/api/ps")).json()
19
  assert len(res["models"]) == 0
20
- await aclient.post(f"/api/ps/{model}")
21
  res = (await aclient.get("/api/ps")).json()
22
  assert len(res["models"]) == 1
23
  await asyncio.sleep(ttl + 1) # wait for the model to be unloaded
@@ -26,13 +25,11 @@ async def test_model_unloaded_after_ttl() -> None:
26
 
27
 
28
  @pytest.mark.asyncio
29
- async def test_ttl_resets_after_usage() -> None:
30
  ttl = 5
31
- model = "Systran/faster-whisper-tiny.en"
32
- os.environ["WHISPER__TTL"] = str(ttl)
33
- os.environ["ENABLE_UI"] = "false"
34
- async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
35
- await aclient.post(f"/api/ps/{model}")
36
  res = (await aclient.get("/api/ps")).json()
37
  assert len(res["models"]) == 1
38
  await asyncio.sleep(ttl - 2) # sleep for less than the ttl. The model should not be unloaded
@@ -43,7 +40,9 @@ async def test_ttl_resets_after_usage() -> None:
43
  data = await f.read()
44
  res = (
45
  await aclient.post(
46
- "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
 
 
47
  )
48
  ).json()
49
  res = (await aclient.get("/api/ps")).json()
@@ -60,28 +59,28 @@ async def test_ttl_resets_after_usage() -> None:
60
  # this just ensures the model can be loaded again after being unloaded
61
  res = (
62
  await aclient.post(
63
- "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
 
 
64
  )
65
  ).json()
66
 
67
 
68
  @pytest.mark.asyncio
69
- async def test_model_cant_be_unloaded_when_used() -> None:
70
  ttl = 0
71
- model = "Systran/faster-whisper-tiny.en"
72
- os.environ["WHISPER__TTL"] = str(ttl)
73
- os.environ["ENABLE_UI"] = "false"
74
- async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
75
  async with await anyio.open_file("audio.wav", "rb") as f:
76
  data = await f.read()
77
 
78
  task = asyncio.create_task(
79
  aclient.post(
80
- "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
81
  )
82
  )
83
  await asyncio.sleep(0.1) # wait for the server to start processing the request
84
- res = await aclient.delete(f"/api/ps/{model}")
85
  assert res.status_code == 409
86
 
87
  await task
@@ -90,27 +89,23 @@ async def test_model_cant_be_unloaded_when_used() -> None:
90
 
91
 
92
  @pytest.mark.asyncio
93
- async def test_model_cant_be_loaded_twice() -> None:
94
  ttl = -1
95
- model = "Systran/faster-whisper-tiny.en"
96
- os.environ["ENABLE_UI"] = "false"
97
- os.environ["WHISPER__TTL"] = str(ttl)
98
- async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
99
- res = await aclient.post(f"/api/ps/{model}")
100
  assert res.status_code == 201
101
- res = await aclient.post(f"/api/ps/{model}")
102
  assert res.status_code == 409
103
  res = (await aclient.get("/api/ps")).json()
104
  assert len(res["models"]) == 1
105
 
106
 
107
  @pytest.mark.asyncio
108
- async def test_model_is_unloaded_after_request_when_ttl_is_zero() -> None:
109
  ttl = 0
110
- os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
111
- os.environ["WHISPER__TTL"] = str(ttl)
112
- os.environ["ENABLE_UI"] = "false"
113
- async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
114
  async with await anyio.open_file("audio.wav", "rb") as f:
115
  data = await f.read()
116
  res = await aclient.post(
 
1
  import asyncio
 
2
 
3
  import anyio
 
4
  import pytest
5
 
6
+ from faster_whisper_server.config import Config, WhisperConfig
7
+ from tests.conftest import DEFAULT_WHISPER_MODEL, AclientFactory
8
+
9
+ MODEL = DEFAULT_WHISPER_MODEL # just to make the test more readable
10
 
11
 
12
  @pytest.mark.asyncio
13
+ async def test_model_unloaded_after_ttl(aclient_factory: AclientFactory) -> None:
14
  ttl = 5
15
+ config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
16
+ async with aclient_factory(config) as aclient:
 
 
17
  res = (await aclient.get("/api/ps")).json()
18
  assert len(res["models"]) == 0
19
+ await aclient.post(f"/api/ps/{MODEL}")
20
  res = (await aclient.get("/api/ps")).json()
21
  assert len(res["models"]) == 1
22
  await asyncio.sleep(ttl + 1) # wait for the model to be unloaded
 
25
 
26
 
27
  @pytest.mark.asyncio
28
+ async def test_ttl_resets_after_usage(aclient_factory: AclientFactory) -> None:
29
  ttl = 5
30
+ config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
31
+ async with aclient_factory(config) as aclient:
32
+ await aclient.post(f"/api/ps/{MODEL}")
 
 
33
  res = (await aclient.get("/api/ps")).json()
34
  assert len(res["models"]) == 1
35
  await asyncio.sleep(ttl - 2) # sleep for less than the ttl. The model should not be unloaded
 
40
  data = await f.read()
41
  res = (
42
  await aclient.post(
43
+ "/v1/audio/transcriptions",
44
+ files={"file": ("audio.wav", data, "audio/wav")},
45
+ data={"model": MODEL},
46
  )
47
  ).json()
48
  res = (await aclient.get("/api/ps")).json()
 
59
  # this just ensures the model can be loaded again after being unloaded
60
  res = (
61
  await aclient.post(
62
+ "/v1/audio/transcriptions",
63
+ files={"file": ("audio.wav", data, "audio/wav")},
64
+ data={"model": MODEL},
65
  )
66
  ).json()
67
 
68
 
69
  @pytest.mark.asyncio
70
+ async def test_model_cant_be_unloaded_when_used(aclient_factory: AclientFactory) -> None:
71
  ttl = 0
72
+ config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
73
+ async with aclient_factory(config) as aclient:
 
 
74
  async with await anyio.open_file("audio.wav", "rb") as f:
75
  data = await f.read()
76
 
77
  task = asyncio.create_task(
78
  aclient.post(
79
+ "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": MODEL}
80
  )
81
  )
82
  await asyncio.sleep(0.1) # wait for the server to start processing the request
83
+ res = await aclient.delete(f"/api/ps/{MODEL}")
84
  assert res.status_code == 409
85
 
86
  await task
 
89
 
90
 
91
  @pytest.mark.asyncio
92
+ async def test_model_cant_be_loaded_twice(aclient_factory: AclientFactory) -> None:
93
  ttl = -1
94
+ config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
95
+ async with aclient_factory(config) as aclient:
96
+ res = await aclient.post(f"/api/ps/{MODEL}")
 
 
97
  assert res.status_code == 201
98
+ res = await aclient.post(f"/api/ps/{MODEL}")
99
  assert res.status_code == 409
100
  res = (await aclient.get("/api/ps")).json()
101
  assert len(res["models"]) == 1
102
 
103
 
104
  @pytest.mark.asyncio
105
+ async def test_model_is_unloaded_after_request_when_ttl_is_zero(aclient_factory: AclientFactory) -> None:
106
  ttl = 0
107
+ config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
108
+ async with aclient_factory(config) as aclient:
 
 
109
  async with await anyio.open_file("audio.wav", "rb") as f:
110
  data = await f.read()
111
  res = await aclient.post(