Spaces:
Configuration error
Configuration error
File size: 4,017 Bytes
8c12cdc 04d664a ec4d8ae f3d078e 04d664a dc4f25f 323aa51 8c12cdc dc4f25f 8c12cdc 04d664a dc4f25f 04d664a 8c12cdc 04d664a 8c12cdc 04d664a 8c12cdc dc4f25f 8c12cdc 04d664a dc4f25f 04d664a 8c12cdc 04d664a 8c12cdc 04d664a ec4d8ae 8c12cdc 04d664a dc4f25f 04d664a 8c12cdc 04d664a 8c12cdc 04d664a ec4d8ae 323aa51 04d664a 323aa51 04d664a 323aa51 04d664a 323aa51 04d664a 323aa51 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import json
import os
import anyio
from faster_whisper_server.api_models import (
CreateTranscriptionResponseJson,
CreateTranscriptionResponseVerboseJson,
)
from httpx import AsyncClient
from httpx_sse import aconnect_sse
import pytest
import srt
import webvtt
import webvtt.vtt
# Audio fixtures uploaded by the parametrized streaming tests.
# HACK: hard-coded relative path — presumably the tests must be run from the
# directory that contains audio.wav; confirm against the test runner setup.
FILE_PATHS = ["audio.wav"]

# API routes exercised by every parametrized streaming test.
ENDPOINTS = ["/v1/audio/transcriptions", "/v1/audio/translations"]

# Every (file_path, endpoint) combination, grouped by endpoint first.
parameters = [
    (audio_path, route)
    for route in ENDPOINTS
    for audio_path in FILE_PATHS
]
@pytest.mark.asyncio()
@pytest.mark.parametrize(("file_path", "endpoint"), parameters)
async def test_streaming_transcription_text(aclient: AsyncClient, file_path: str, endpoint: str) -> None:
    """Stream a ``text`` response over SSE and check that each event carries content.

    Args:
        aclient: Async HTTP client used for the request (not parametrized here,
            so it must come from a pytest fixture defined elsewhere).
        file_path: Path of the audio file to upload.
        endpoint: API route the audio is POSTed to.
    """
    # splitext() keeps the leading dot (".wav"); strip it so the upload is named
    # "audio.wav" with content type "audio/wav" instead of "audio..wav" / "audio/.wav".
    extension = os.path.splitext(file_path)[1].lstrip(".")
    async with await anyio.open_file(file_path, "rb") as f:
        data = await f.read()
    kwargs = {
        "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
        "data": {"response_format": "text", "stream": True},
    }
    async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source:
        async for event in event_source.aiter_sse():
            print(event)
            assert len(event.data) > 1  # HACK: 1 because of the space character that's always prepended
@pytest.mark.asyncio()
@pytest.mark.parametrize(("file_path", "endpoint"), parameters)
async def test_streaming_transcription_json(aclient: AsyncClient, file_path: str, endpoint: str) -> None:
    """Stream a ``json`` response over SSE and validate the shape of every event.

    Each event payload is decoded as JSON and used to construct a
    ``CreateTranscriptionResponseJson``; the test fails if decoding or
    construction raises.

    Args:
        aclient: Async HTTP client used for the request (not parametrized here,
            so it must come from a pytest fixture defined elsewhere).
        file_path: Path of the audio file to upload.
        endpoint: API route the audio is POSTed to.
    """
    # splitext() keeps the leading dot (".wav"); strip it so the upload is named
    # "audio.wav" with content type "audio/wav" instead of "audio..wav" / "audio/.wav".
    extension = os.path.splitext(file_path)[1].lstrip(".")
    async with await anyio.open_file(file_path, "rb") as f:
        data = await f.read()
    kwargs = {
        "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
        "data": {"response_format": "json", "stream": True},
    }
    async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source:
        async for event in event_source.aiter_sse():
            CreateTranscriptionResponseJson(**json.loads(event.data))
@pytest.mark.asyncio()
@pytest.mark.parametrize(("file_path", "endpoint"), parameters)
async def test_streaming_transcription_verbose_json(aclient: AsyncClient, file_path: str, endpoint: str) -> None:
    """Stream a ``verbose_json`` response over SSE and validate the shape of every event.

    Each event payload is decoded as JSON and used to construct a
    ``CreateTranscriptionResponseVerboseJson``; the test fails if decoding or
    construction raises.

    Args:
        aclient: Async HTTP client used for the request (not parametrized here,
            so it must come from a pytest fixture defined elsewhere).
        file_path: Path of the audio file to upload.
        endpoint: API route the audio is POSTed to.
    """
    # splitext() keeps the leading dot (".wav"); strip it so the upload is named
    # "audio.wav" with content type "audio/wav" instead of "audio..wav" / "audio/.wav".
    extension = os.path.splitext(file_path)[1].lstrip(".")
    async with await anyio.open_file(file_path, "rb") as f:
        data = await f.read()
    kwargs = {
        "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
        "data": {"response_format": "verbose_json", "stream": True},
    }
    async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source:
        async for event in event_source.aiter_sse():
            CreateTranscriptionResponseVerboseJson(**json.loads(event.data))
@pytest.mark.asyncio()
async def test_transcription_vtt(aclient: AsyncClient) -> None:
    """Request a non-streaming ``vtt`` transcription and check that it parses as WebVTT.

    Also confirms the parser is strict: once every "WEBVTT" occurrence in the
    body is replaced with "YO", parsing must raise ``MalformedFileError``.

    Args:
        aclient: Async HTTP client used for the request (it must come from a
            pytest fixture defined elsewhere).
    """
    async with await anyio.open_file("audio.wav", "rb") as audio_file:
        audio_bytes = await audio_file.read()

    response = await aclient.post(
        "/v1/audio/transcriptions",
        files={"file": ("audio.wav", audio_bytes, "audio/wav")},
        data={"response_format": "vtt", "stream": False},
    )
    assert response.status_code == 200
    assert response.headers["content-type"] == "text/vtt; charset=utf-8"

    valid_vtt = response.text
    webvtt.from_string(valid_vtt)  # must parse without raising

    corrupted_vtt = valid_vtt.replace("WEBVTT", "YO")
    with pytest.raises(webvtt.vtt.MalformedFileError):
        webvtt.from_string(corrupted_vtt)
@pytest.mark.asyncio()
async def test_transcription_srt(aclient: AsyncClient) -> None:
    """Request a non-streaming ``srt`` transcription and check that it parses as SRT.

    Also confirms the parser is strict: once every "1" character in the body
    is replaced with "YO", parsing must raise ``SRTParseError``.

    Args:
        aclient: Async HTTP client used for the request (it must come from a
            pytest fixture defined elsewhere).
    """
    async with await anyio.open_file("audio.wav", "rb") as audio_file:
        audio_bytes = await audio_file.read()

    response = await aclient.post(
        "/v1/audio/transcriptions",
        files={"file": ("audio.wav", audio_bytes, "audio/wav")},
        data={"response_format": "srt", "stream": False},
    )
    assert response.status_code == 200
    assert "text/plain" in response.headers["content-type"]

    valid_srt = response.text
    # srt.parse() is lazy; list() forces the whole body through the parser.
    list(srt.parse(valid_srt))

    corrupted_srt = valid_srt.replace("1", "YO")
    with pytest.raises(srt.SRTParseError):
        list(srt.parse(corrupted_srt))
|