File size: 4,017 Bytes
8c12cdc
 
 
04d664a
ec4d8ae
 
 
f3d078e
04d664a
 
dc4f25f
323aa51
 
 
8c12cdc
 
 
 
 
 
 
 
dc4f25f
8c12cdc
 
04d664a
dc4f25f
04d664a
8c12cdc
04d664a
 
8c12cdc
 
 
 
04d664a
 
8c12cdc
dc4f25f
8c12cdc
 
04d664a
dc4f25f
04d664a
8c12cdc
04d664a
 
8c12cdc
 
 
 
04d664a
 
ec4d8ae
8c12cdc
 
04d664a
dc4f25f
04d664a
8c12cdc
04d664a
 
8c12cdc
 
 
 
04d664a
 
ec4d8ae
323aa51
 
04d664a
 
 
 
323aa51
 
 
 
04d664a
323aa51
 
 
 
 
 
 
 
 
04d664a
 
 
 
323aa51
 
 
 
04d664a
323aa51
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import json
import os

import anyio
from faster_whisper_server.api_models import (
    CreateTranscriptionResponseJson,
    CreateTranscriptionResponseVerboseJson,
)
from httpx import AsyncClient
from httpx_sse import aconnect_sse
import pytest
import srt
import webvtt
import webvtt.vtt

# HACK: a bare relative path, so it is resolved against the current working
# directory -- the tests only find it when run from the directory holding it.
FILE_PATHS = ["audio.wav"]

# Audio routes exercised by the parametrized streaming tests below.
ENDPOINTS = ["/v1/audio/transcriptions", "/v1/audio/translations"]


# Every (file_path, endpoint) combination, ordered endpoint-major: all files
# for the first endpoint, then all files for the next. Consumed by
# pytest.mark.parametrize(("file_path", "endpoint"), parameters).
parameters = [
    (audio_path, route)
    for route in ENDPOINTS
    for audio_path in FILE_PATHS
]


@pytest.mark.asyncio()
@pytest.mark.parametrize(("file_path", "endpoint"), parameters)
async def test_streaming_transcription_text(aclient: AsyncClient, file_path: str, endpoint: str) -> None:
    """Stream a ``text`` response over SSE and check every event carries text.

    The audio file is uploaded as multipart form data with ``stream=True``.
    Each server-sent event's ``data`` must be longer than one character. At
    least one event must arrive, so an empty stream cannot pass vacuously.
    """
    # splitext keeps the leading dot (".wav"); strip it so the upload is named
    # "audio.wav" with MIME type "audio/wav" instead of "audio..wav" / "audio/.wav".
    extension = os.path.splitext(file_path)[1].lstrip(".")
    async with await anyio.open_file(file_path, "rb") as f:
        data = await f.read()
    kwargs = {
        "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
        "data": {"response_format": "text", "stream": True},
    }
    event_count = 0
    async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source:
        async for event in event_source.aiter_sse():
            event_count += 1
            assert len(event.data) > 1  # HACK: 1 because of the space character that's always prepended
    assert event_count > 0, "server sent no SSE events"


@pytest.mark.asyncio()
@pytest.mark.parametrize(("file_path", "endpoint"), parameters)
async def test_streaming_transcription_json(aclient: AsyncClient, file_path: str, endpoint: str) -> None:
    """Stream a ``json`` response over SSE and check each event's payload.

    Every event's ``data`` is decoded as JSON and used to construct a
    ``CreateTranscriptionResponseJson``. Construction presumably raises if the
    payload does not match the model's fields; confirm against api_models. At
    least one event must arrive, so an empty stream cannot pass vacuously.
    """
    # splitext keeps the leading dot (".wav"); strip it so the upload is named
    # "audio.wav" with MIME type "audio/wav" instead of "audio..wav" / "audio/.wav".
    extension = os.path.splitext(file_path)[1].lstrip(".")
    async with await anyio.open_file(file_path, "rb") as f:
        data = await f.read()
    kwargs = {
        "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
        "data": {"response_format": "json", "stream": True},
    }
    event_count = 0
    async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source:
        async for event in event_source.aiter_sse():
            event_count += 1
            CreateTranscriptionResponseJson(**json.loads(event.data))
    assert event_count > 0, "server sent no SSE events"


@pytest.mark.asyncio()
@pytest.mark.parametrize(("file_path", "endpoint"), parameters)
async def test_streaming_transcription_verbose_json(aclient: AsyncClient, file_path: str, endpoint: str) -> None:
    """Stream a ``verbose_json`` response over SSE and check each event's payload.

    Every event's ``data`` is decoded as JSON and used to construct a
    ``CreateTranscriptionResponseVerboseJson``. Construction presumably raises
    if the payload does not match the model's fields; confirm against
    api_models. At least one event must arrive, so an empty stream cannot
    pass vacuously.
    """
    # splitext keeps the leading dot (".wav"); strip it so the upload is named
    # "audio.wav" with MIME type "audio/wav" instead of "audio..wav" / "audio/.wav".
    extension = os.path.splitext(file_path)[1].lstrip(".")
    async with await anyio.open_file(file_path, "rb") as f:
        data = await f.read()
    kwargs = {
        "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
        "data": {"response_format": "verbose_json", "stream": True},
    }
    event_count = 0
    async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source:
        async for event in event_source.aiter_sse():
            event_count += 1
            CreateTranscriptionResponseVerboseJson(**json.loads(event.data))
    assert event_count > 0, "server sent no SSE events"


@pytest.mark.asyncio()
async def test_transcription_vtt(aclient: AsyncClient) -> None:
    """Non-streaming ``vtt`` transcription returns parseable WebVTT.

    Checks the status code and the exact content type, and that the body
    parses with ``webvtt``. As a sanity check on the parser itself, the same
    body must be rejected once its ``WEBVTT`` header is replaced.
    """
    async with await anyio.open_file("audio.wav", "rb") as audio_file:
        audio_bytes = await audio_file.read()
    upload = {"file": ("audio.wav", audio_bytes, "audio/wav")}
    form = {"response_format": "vtt", "stream": False}

    resp = await aclient.post("/v1/audio/transcriptions", files=upload, data=form)
    assert resp.status_code == 200
    assert resp.headers["content-type"] == "text/vtt; charset=utf-8"

    body = resp.text
    webvtt.from_string(body)  # must parse without raising

    # Without the "WEBVTT" signature the text is expected to be rejected,
    # which shows the successful parse above actually validated something.
    broken = body.replace("WEBVTT", "YO")
    with pytest.raises(webvtt.vtt.MalformedFileError):
        webvtt.from_string(broken)


@pytest.mark.asyncio()
async def test_transcription_srt(aclient: AsyncClient) -> None:
    """Non-streaming ``srt`` transcription returns parseable SubRip text.

    Checks the status code and that the content type contains ``text/plain``,
    and that the body parses with ``srt``. As a sanity check on the parser,
    a deliberately corrupted copy of the body must fail to parse.
    """
    async with await anyio.open_file("audio.wav", "rb") as audio_file:
        audio_bytes = await audio_file.read()
    upload = {"file": ("audio.wav", audio_bytes, "audio/wav")}
    form = {"response_format": "srt", "stream": False}

    resp = await aclient.post("/v1/audio/transcriptions", files=upload, data=form)
    assert resp.status_code == 200
    assert "text/plain" in resp.headers["content-type"]

    body = resp.text
    # srt.parse yields subtitles lazily; list() drives the whole body through it.
    list(srt.parse(body))

    # Replacing every "1" corrupts cue indices/timestamps that contain that
    # digit (the first cue index is conventionally "1"), so parsing must fail.
    broken = body.replace("1", "YO")
    with pytest.raises(srt.SRTParseError):
        list(srt.parse(broken))