Spaces:
Configuration error
Configuration error
| import json | |
| import os | |
| import threading | |
| import time | |
| from difflib import SequenceMatcher | |
| from typing import Generator | |
| import pytest | |
| from fastapi import WebSocketDisconnect | |
| from fastapi.testclient import TestClient | |
| from starlette.testclient import WebSocketTestSession | |
| from speaches.config import BYTES_PER_SECOND | |
| from speaches.main import app | |
| from speaches.server_models import TranscriptionVerboseResponse | |
# Similarity ratio (difflib.SequenceMatcher) that the streaming transcript must
# strictly exceed, relative to the file-upload transcript, for the test to pass.
SIMILARITY_THRESHOLD = 0.97
@pytest.fixture()
def client() -> Generator[TestClient, None, None]:
    """Yield a ``TestClient`` bound to the application under test.

    The client is used as a context manager, so it is entered before the
    requesting test runs and exited once that test has finished.

    This must be registered as a pytest fixture: the test in this module
    asks for it by declaring a ``client`` parameter, which pytest can only
    resolve when the function is decorated.

    Yields:
        The entered ``TestClient`` instance.
    """
    # A distinct local name avoids shadowing the fixture function itself.
    with TestClient(app) as test_client:
        yield test_client
def get_audio_file_paths(directory: str = "tests/data") -> list[str]:
    """Return the ``.raw`` audio file selected from *directory*.

    Only the sixth entry (index 5) of the directory listing is considered, so
    the result holds at most one path. The listing is sorted before slicing
    because ``os.listdir`` returns entries in arbitrary order, which would
    otherwise let the selected file differ between machines and runs.

    Args:
        directory: Folder to scan. Defaults to ``tests/data``, resolved
            against the current working directory.

    Returns:
        ``directory`` joined with the name of each selected entry ending in
        ``.raw``; an empty list when the selected entry does not exist or is
        not a ``.raw`` file.

    Raises:
        OSError: If *directory* cannot be listed.
    """
    selected = sorted(os.listdir(directory))[5:6]
    # ``reversed`` is kept so the ordering stays the same if the slice above
    # is ever widened to cover more than one entry.
    return [
        os.path.join(directory, filename)
        for filename in reversed(selected)
        if filename.endswith(".raw")
    ]
# Resolved once, when this module is imported.
# NOTE(review): presumably meant to supply the ``file_path`` values for the
# test in this module — confirm how it is consumed.
file_paths = get_audio_file_paths()
def stream_audio_data(
    ws: WebSocketTestSession, data: bytes, *, chunk_size: int = 4000, speed: float = 1.0
):
    """Send *data* over *ws* in consecutive chunks, pausing after each one.

    Every chunk holds at most ``chunk_size`` bytes. After a chunk is sent the
    function sleeps for the chunk length divided by ``BYTES_PER_SECOND`` and
    then by ``speed``, so a larger ``speed`` shortens every pause.

    Args:
        ws: Session whose ``send_bytes`` receives each chunk.
        data: Bytes to send.
        chunk_size: Maximum number of bytes per message.
        speed: Divisor applied to each pause.
    """
    for start in range(0, len(data), chunk_size):
        chunk = data[start : start + chunk_size]
        ws.send_bytes(chunk)
        pause = len(chunk) / BYTES_PER_SECOND / speed
        time.sleep(pause)
def transcribe_audio_data(
    client: TestClient, data: bytes
) -> TranscriptionVerboseResponse:
    """Transcribe *data* through the HTTP file-upload endpoint.

    Posts the bytes as a multipart upload named ``audio.raw`` to
    ``/v1/audio/transcriptions`` with ``response_format=verbose_json`` and
    builds a ``TranscriptionVerboseResponse`` from the reply.

    Args:
        client: Test client used to issue the request.
        data: Audio bytes to upload.

    Returns:
        The parsed verbose transcription.

    Raises:
        Exception: Whatever ``response.raise_for_status()`` raises when the
            server answers with an error status.
    """
    response = client.post(
        "/v1/audio/transcriptions?response_format=verbose_json",
        files={"file": ("audio.raw", data, "audio/raw")},
    )
    # Surface a rejected request as an HTTP error instead of as a confusing
    # parse or validation failure further down.
    response.raise_for_status()

    payload = response.json()
    # The body may be a JSON string that itself contains the JSON document
    # (double-encoded). Decode a second time only in that case, so the helper
    # also works when the endpoint returns a plain JSON object.
    if isinstance(payload, str):
        payload = json.loads(payload)
    return TranscriptionVerboseResponse(**payload)  # type: ignore
@pytest.mark.parametrize("file_path", file_paths)
def test_ws_audio_transcriptions(client: TestClient, file_path: str):
    """Check that streaming transcription matches file-upload transcription.

    The audio file is streamed over the WebSocket endpoint from a background
    thread while this thread reads responses until the server disconnects.
    Each response overwrites the previous one, so the last message received
    is the one compared. The same bytes are then sent through the HTTP upload
    endpoint, and the two texts must have a ``SequenceMatcher`` ratio above
    ``SIMILARITY_THRESHOLD``.

    Args:
        client: Test client for the application.
        file_path: Path of the raw audio file to transcribe; supplied through
            parametrization over ``file_paths``.
    """
    with open(file_path, "rb") as file:
        data = file.read()

    streaming_transcription: TranscriptionVerboseResponse = None  # type: ignore
    with client.websocket_connect(
        "/v1/audio/transcriptions?response_format=verbose_json"
    ) as ws:
        # Sending happens on a separate thread because the loop below blocks
        # on receive_json() while audio is still being pushed.
        thread = threading.Thread(
            target=stream_audio_data, args=(ws, data), kwargs={"speed": 4.0}
        )
        thread.start()
        while True:
            try:
                streaming_transcription = TranscriptionVerboseResponse(
                    **ws.receive_json()
                )
            except WebSocketDisconnect:
                break
        # Wait for the sender so it cannot keep running into later tests.
        thread.join()
        ws.close()

    # Without this check, a connection that closes before any message arrives
    # would fail below with an unhelpful AttributeError on None.
    assert (
        streaming_transcription is not None
    ), "WebSocket closed before any streaming transcription was received"

    file_transcription = transcribe_audio_data(client, data)
    s = SequenceMatcher(
        lambda x: x == " ", file_transcription.text, streaming_transcription.text
    )
    assert (
        s.ratio() > SIMILARITY_THRESHOLD
    ), f"\nExpected: {file_transcription.text}\nReceived: {streaming_transcription.text}"