Fedir Zadniprovskyi committed on
Commit
2b42466
·
1 Parent(s): 3344380

chore: use async handlers in gradio

Browse files
src/faster_whisper_server/gradio_app.py CHANGED
@@ -1,10 +1,10 @@
1
- from collections.abc import Generator
2
  from pathlib import Path
3
 
4
  import gradio as gr
5
  import httpx
6
- from httpx_sse import connect_sse
7
- from openai import OpenAI
8
 
9
  from faster_whisper_server.config import Config, Task
10
  from faster_whisper_server.hf_utils import PiperModel
@@ -25,13 +25,19 @@ TIMEOUT = httpx.Timeout(timeout=TIMEOUT_SECONDS)
25
 
26
  def create_gradio_demo(config: Config) -> gr.Blocks: # noqa: C901, PLR0915
27
  base_url = f"http://{config.host}:{config.port}"
28
- http_client = httpx.Client(base_url=base_url, timeout=TIMEOUT)
29
- openai_client = OpenAI(base_url=f"{base_url}/v1", api_key="cant-be-empty")
30
-
31
- # TODO: make async
32
- def whisper_handler(
 
 
 
 
 
 
33
  file_path: str, model: str, task: Task, temperature: float, stream: bool
34
- ) -> Generator[str, None, None]:
35
  if task == Task.TRANSCRIBE:
36
  endpoint = TRANSCRIPTION_ENDPOINT
37
  elif task == Task.TRANSLATE:
@@ -39,15 +45,15 @@ def create_gradio_demo(config: Config) -> gr.Blocks: # noqa: C901, PLR0915
39
 
40
  if stream:
41
  previous_transcription = ""
42
- for transcription in streaming_audio_task(file_path, endpoint, temperature, model):
43
  previous_transcription += transcription
44
  yield previous_transcription
45
  else:
46
- yield audio_task(file_path, endpoint, temperature, model)
47
 
48
- def audio_task(file_path: str, endpoint: str, temperature: float, model: str) -> str:
49
- with Path(file_path).open("rb") as file:
50
- response = http_client.post(
51
  endpoint,
52
  files={"file": file},
53
  data={
@@ -60,10 +66,10 @@ def create_gradio_demo(config: Config) -> gr.Blocks: # noqa: C901, PLR0915
60
  response.raise_for_status()
61
  return response.text
62
 
63
- def streaming_audio_task(
64
  file_path: str, endpoint: str, temperature: float, model: str
65
- ) -> Generator[str, None, None]:
66
- with Path(file_path).open("rb") as file:
67
  kwargs = {
68
  "files": {"file": file},
69
  "data": {
@@ -73,12 +79,12 @@ def create_gradio_demo(config: Config) -> gr.Blocks: # noqa: C901, PLR0915
73
  "stream": True,
74
  },
75
  }
76
- with connect_sse(http_client, "POST", endpoint, **kwargs) as event_source:
77
- for event in event_source.iter_sse():
78
  yield event.data
79
 
80
- def update_whisper_model_dropdown() -> gr.Dropdown:
81
- models = openai_client.models.list().data
82
  model_names: list[str] = [model.id for model in models]
83
  assert config.whisper.model in model_names
84
  recommended_models = {model for model in model_names if model.startswith("Systran")}
@@ -90,14 +96,15 @@ def create_gradio_demo(config: Config) -> gr.Blocks: # noqa: C901, PLR0915
90
  value=config.whisper.model,
91
  )
92
 
93
- def update_piper_voices_dropdown() -> gr.Dropdown:
94
- res = http_client.get("/v1/audio/speech/voices").raise_for_status()
95
  piper_models = [PiperModel.model_validate(x) for x in res.json()]
96
  return gr.Dropdown(choices=[model.voice for model in piper_models], label="Voice", value=DEFAULT_VOICE)
97
 
98
- # TODO: make async
99
- def handle_audio_speech(text: str, voice: str, response_format: str, speed: float, sample_rate: int | None) -> Path:
100
- res = openai_client.audio.speech.create(
 
101
  input=text,
102
  model="piper",
103
  voice=voice, # pyright: ignore[reportArgumentType]
@@ -107,7 +114,7 @@ def create_gradio_demo(config: Config) -> gr.Blocks: # noqa: C901, PLR0915
107
  )
108
  audio_bytes = res.response.read()
109
  file_path = Path(f"audio.{response_format}")
110
- with file_path.open("wb") as file:
111
  file.write(audio_bytes)
112
  return file_path
113
 
 
1
+ from collections.abc import AsyncGenerator
2
  from pathlib import Path
3
 
4
  import gradio as gr
5
  import httpx
6
+ from httpx_sse import aconnect_sse
7
+ from openai import AsyncOpenAI
8
 
9
  from faster_whisper_server.config import Config, Task
10
  from faster_whisper_server.hf_utils import PiperModel
 
25
 
26
  def create_gradio_demo(config: Config) -> gr.Blocks: # noqa: C901, PLR0915
27
  base_url = f"http://{config.host}:{config.port}"
28
+ # TODO: test that auth works
29
+ http_client = httpx.AsyncClient(
30
+ base_url=base_url,
31
+ timeout=TIMEOUT,
32
+ headers={"Authorization": f"Bearer {config.api_key}"} if config.api_key else {},
33
+ )
34
+ openai_client = AsyncOpenAI(
35
+ base_url=f"{base_url}/v1", api_key=config.api_key if config.api_key else "cant-be-empty"
36
+ )
37
+
38
+ async def whisper_handler(
39
  file_path: str, model: str, task: Task, temperature: float, stream: bool
40
+ ) -> AsyncGenerator[str, None]:
41
  if task == Task.TRANSCRIBE:
42
  endpoint = TRANSCRIPTION_ENDPOINT
43
  elif task == Task.TRANSLATE:
 
45
 
46
  if stream:
47
  previous_transcription = ""
48
+ async for transcription in streaming_audio_task(file_path, endpoint, temperature, model):
49
  previous_transcription += transcription
50
  yield previous_transcription
51
  else:
52
+ yield await audio_task(file_path, endpoint, temperature, model)
53
 
54
+ async def audio_task(file_path: str, endpoint: str, temperature: float, model: str) -> str:
55
+ with Path(file_path).open("rb") as file: # noqa: ASYNC230
56
+ response = await http_client.post(
57
  endpoint,
58
  files={"file": file},
59
  data={
 
66
  response.raise_for_status()
67
  return response.text
68
 
69
+ async def streaming_audio_task(
70
  file_path: str, endpoint: str, temperature: float, model: str
71
+ ) -> AsyncGenerator[str, None]:
72
+ with Path(file_path).open("rb") as file: # noqa: ASYNC230
73
  kwargs = {
74
  "files": {"file": file},
75
  "data": {
 
79
  "stream": True,
80
  },
81
  }
82
+ async with aconnect_sse(http_client, "POST", endpoint, **kwargs) as event_source:
83
+ async for event in event_source.aiter_sse():
84
  yield event.data
85
 
86
+ async def update_whisper_model_dropdown() -> gr.Dropdown:
87
+ models = (await openai_client.models.list()).data
88
  model_names: list[str] = [model.id for model in models]
89
  assert config.whisper.model in model_names
90
  recommended_models = {model for model in model_names if model.startswith("Systran")}
 
96
  value=config.whisper.model,
97
  )
98
 
99
+ async def update_piper_voices_dropdown() -> gr.Dropdown:
100
+ res = (await http_client.get("/v1/audio/speech/voices")).raise_for_status()
101
  piper_models = [PiperModel.model_validate(x) for x in res.json()]
102
  return gr.Dropdown(choices=[model.voice for model in piper_models], label="Voice", value=DEFAULT_VOICE)
103
 
104
+ async def handle_audio_speech(
105
+ text: str, voice: str, response_format: str, speed: float, sample_rate: int | None
106
+ ) -> Path:
107
+ res = await openai_client.audio.speech.create(
108
  input=text,
109
  model="piper",
110
  voice=voice, # pyright: ignore[reportArgumentType]
 
114
  )
115
  audio_bytes = res.response.read()
116
  file_path = Path(f"audio.{response_format}")
117
+ with file_path.open("wb") as file: # noqa: ASYNC230
118
  file.write(audio_bytes)
119
  return file_path
120