from http import HTTPStatus

import numpy as np
from kui.asgi import HTTPException

from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.utils.schema import ServeTTSRequest

# Scale factor for converting float audio in [-1.0, 1.0] to 16-bit PCM.
AMPLITUDE = 32768


def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine):
    """
    Wrapper for the inference function.
    Used in the API server.
    """
    count = 0
    for result in engine.inference(req):
        match result.code:
            case "header":
                # Forward the audio header payload as-is.
                if isinstance(result.audio, tuple):
                    yield result.audio[1]

            case "error":
                raise HTTPException(
                    HTTPStatus.INTERNAL_SERVER_ERROR,
                    content=str(result.error),
                )

            case "segment":
                count += 1
                # Convert the float waveform to 16-bit PCM bytes for streaming.
                if isinstance(result.audio, tuple):
                    yield (result.audio[1] * AMPLITUDE).astype(np.int16).tobytes()

            case "final":
                count += 1
                if isinstance(result.audio, tuple):
                    yield result.audio[1]
                return None  # Stop the generator

    if count == 0:
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR,
            content="No audio generated, please check the input text.",
        )
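

# Illustrative usage sketch (an assumption, not part of the original module):
# a consumer that drains the generator into a single bytes payload. The helper
# name `collect_audio_bytes` is hypothetical; segment chunks are already PCM
# bytes, while the header/final payloads may be numpy arrays, so those are
# converted the same way the "segment" branch does.
def collect_audio_bytes(req: ServeTTSRequest, engine: TTSInferenceEngine) -> bytes:
    chunks = []
    for chunk in inference_wrapper(req, engine):
        if isinstance(chunk, np.ndarray):
            # Convert float audio to 16-bit PCM bytes before joining.
            chunk = (chunk * AMPLITUDE).astype(np.int16).tobytes()
        chunks.append(chunk)
    return b"".join(chunks)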