Fedir Zadniprovskyi commited on
Commit
f58fddb
·
1 Parent(s): 2b42466

chore: get base url from the request url in gradio

Browse files
src/faster_whisper_server/gradio_app.py CHANGED
@@ -22,22 +22,34 @@ TRANSLATION_ENDPOINT = "/v1/audio/translations"
22
  TIMEOUT_SECONDS = 180
23
  TIMEOUT = httpx.Timeout(timeout=TIMEOUT_SECONDS)
24
 
 
25
 
26
- def create_gradio_demo(config: Config) -> gr.Blocks: # noqa: C901, PLR0915
27
- base_url = f"http://{config.host}:{config.port}"
28
- # TODO: test that auth works
29
- http_client = httpx.AsyncClient(
 
 
 
 
 
 
30
  base_url=base_url,
31
  timeout=TIMEOUT,
32
- headers={"Authorization": f"Bearer {config.api_key}"} if config.api_key else {},
33
- )
34
- openai_client = AsyncOpenAI(
35
- base_url=f"{base_url}/v1", api_key=config.api_key if config.api_key else "cant-be-empty"
36
  )
37
 
 
 
 
 
 
 
 
38
  async def whisper_handler(
39
- file_path: str, model: str, task: Task, temperature: float, stream: bool
40
  ) -> AsyncGenerator[str, None]:
 
41
  if task == Task.TRANSCRIBE:
42
  endpoint = TRANSCRIPTION_ENDPOINT
43
  elif task == Task.TRANSLATE:
@@ -45,13 +57,15 @@ def create_gradio_demo(config: Config) -> gr.Blocks: # noqa: C901, PLR0915
45
 
46
  if stream:
47
  previous_transcription = ""
48
- async for transcription in streaming_audio_task(file_path, endpoint, temperature, model):
49
  previous_transcription += transcription
50
  yield previous_transcription
51
  else:
52
- yield await audio_task(file_path, endpoint, temperature, model)
53
 
54
- async def audio_task(file_path: str, endpoint: str, temperature: float, model: str) -> str:
 
 
55
  with Path(file_path).open("rb") as file: # noqa: ASYNC230
56
  response = await http_client.post(
57
  endpoint,
@@ -67,7 +81,7 @@ def create_gradio_demo(config: Config) -> gr.Blocks: # noqa: C901, PLR0915
67
  return response.text
68
 
69
  async def streaming_audio_task(
70
- file_path: str, endpoint: str, temperature: float, model: str
71
  ) -> AsyncGenerator[str, None]:
72
  with Path(file_path).open("rb") as file: # noqa: ASYNC230
73
  kwargs = {
@@ -83,7 +97,8 @@ def create_gradio_demo(config: Config) -> gr.Blocks: # noqa: C901, PLR0915
83
  async for event in event_source.aiter_sse():
84
  yield event.data
85
 
86
- async def update_whisper_model_dropdown() -> gr.Dropdown:
 
87
  models = (await openai_client.models.list()).data
88
  model_names: list[str] = [model.id for model in models]
89
  assert config.whisper.model in model_names
@@ -96,14 +111,16 @@ def create_gradio_demo(config: Config) -> gr.Blocks: # noqa: C901, PLR0915
96
  value=config.whisper.model,
97
  )
98
 
99
- async def update_piper_voices_dropdown() -> gr.Dropdown:
 
100
  res = (await http_client.get("/v1/audio/speech/voices")).raise_for_status()
101
  piper_models = [PiperModel.model_validate(x) for x in res.json()]
102
  return gr.Dropdown(choices=[model.voice for model in piper_models], label="Voice", value=DEFAULT_VOICE)
103
 
104
  async def handle_audio_speech(
105
- text: str, voice: str, response_format: str, speed: float, sample_rate: int | None
106
  ) -> Path:
 
107
  res = await openai_client.audio.speech.create(
108
  input=text,
109
  model="piper",
 
22
  TIMEOUT_SECONDS = 180
23
  TIMEOUT = httpx.Timeout(timeout=TIMEOUT_SECONDS)
24
 
25
+ # NOTE: `gr.Request` seems to be passed in as the last positional (not keyword) argument
26
 
27
+
28
+ def base_url_from_gradio_req(request: gr.Request) -> str:
29
+ # NOTE: `request.request.url` seems to always have a path of "/gradio_api/queue/join"
30
+ assert request.request is not None
31
+ return f"{request.request.url.scheme}://{request.request.url.netloc}"
32
+
33
+
34
+ def http_client_from_gradio_req(request: gr.Request, config: Config) -> httpx.AsyncClient:
35
+ base_url = base_url_from_gradio_req(request)
36
+ return httpx.AsyncClient(
37
  base_url=base_url,
38
  timeout=TIMEOUT,
39
+ headers={"Authorization": f"Bearer {config.api_key}"} if config.api_key else None,
 
 
 
40
  )
41
 
42
+
43
+ def openai_client_from_gradio_req(request: gr.Request, config: Config) -> AsyncOpenAI:
44
+ base_url = base_url_from_gradio_req(request)
45
+ return AsyncOpenAI(base_url=f"{base_url}/v1", api_key=config.api_key if config.api_key else "cant-be-empty")
46
+
47
+
48
+ def create_gradio_demo(config: Config) -> gr.Blocks: # noqa: C901, PLR0915
49
  async def whisper_handler(
50
+ file_path: str, model: str, task: Task, temperature: float, stream: bool, request: gr.Request
51
  ) -> AsyncGenerator[str, None]:
52
+ http_client = http_client_from_gradio_req(request, config)
53
  if task == Task.TRANSCRIBE:
54
  endpoint = TRANSCRIPTION_ENDPOINT
55
  elif task == Task.TRANSLATE:
 
57
 
58
  if stream:
59
  previous_transcription = ""
60
+ async for transcription in streaming_audio_task(http_client, file_path, endpoint, temperature, model):
61
  previous_transcription += transcription
62
  yield previous_transcription
63
  else:
64
+ yield await audio_task(http_client, file_path, endpoint, temperature, model)
65
 
66
+ async def audio_task(
67
+ http_client: httpx.AsyncClient, file_path: str, endpoint: str, temperature: float, model: str
68
+ ) -> str:
69
  with Path(file_path).open("rb") as file: # noqa: ASYNC230
70
  response = await http_client.post(
71
  endpoint,
 
81
  return response.text
82
 
83
  async def streaming_audio_task(
84
+ http_client: httpx.AsyncClient, file_path: str, endpoint: str, temperature: float, model: str
85
  ) -> AsyncGenerator[str, None]:
86
  with Path(file_path).open("rb") as file: # noqa: ASYNC230
87
  kwargs = {
 
97
  async for event in event_source.aiter_sse():
98
  yield event.data
99
 
100
+ async def update_whisper_model_dropdown(request: gr.Request) -> gr.Dropdown:
101
+ openai_client = openai_client_from_gradio_req(request, config)
102
  models = (await openai_client.models.list()).data
103
  model_names: list[str] = [model.id for model in models]
104
  assert config.whisper.model in model_names
 
111
  value=config.whisper.model,
112
  )
113
 
114
+ async def update_piper_voices_dropdown(request: gr.Request) -> gr.Dropdown:
115
+ http_client = http_client_from_gradio_req(request, config)
116
  res = (await http_client.get("/v1/audio/speech/voices")).raise_for_status()
117
  piper_models = [PiperModel.model_validate(x) for x in res.json()]
118
  return gr.Dropdown(choices=[model.voice for model in piper_models], label="Voice", value=DEFAULT_VOICE)
119
 
120
  async def handle_audio_speech(
121
+ text: str, voice: str, response_format: str, speed: float, sample_rate: int | None, request: gr.Request
122
  ) -> Path:
123
+ openai_client = openai_client_from_gradio_req(request, config)
124
  res = await openai_client.audio.speech.create(
125
  input=text,
126
  model="piper",