|
from argparse import ArgumentParser |
|
from http import HTTPStatus |
|
from typing import Annotated, Any |
|
|
|
import ormsgpack |
|
from baize.datastructures import ContentType |
|
from kui.asgi import HTTPException, HttpRequest |
|
|
|
from fish_speech.inference_engine import TTSInferenceEngine |
|
from fish_speech.utils.schema import ServeTTSRequest |
|
from tools.server.inference import inference_wrapper as inference |
|
|
|
|
|
def parse_args():
    """Parse the server's command-line options from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` produced by the parser. Option names are
        exposed as attributes with dashes turned into underscores (for
        example ``--llama-checkpoint-path`` -> ``llama_checkpoint_path``).
    """
    cli = ArgumentParser()

    # (flag, keyword arguments for ``add_argument``), registered in order so
    # the generated ``--help`` output keeps the same layout.
    option_specs = (
        ("--mode", {"type": str, "choices": ["agent", "tts"], "default": "tts"}),
        ("--load-asr-model", {"action": "store_true"}),
        (
            "--llama-checkpoint-path",
            {"type": str, "default": "checkpoints/fish-speech-1.5"},
        ),
        (
            "--decoder-checkpoint-path",
            {
                "type": str,
                "default": "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
            },
        ),
        ("--decoder-config-name", {"type": str, "default": "firefly_gan_vq"}),
        ("--device", {"type": str, "default": "cuda"}),
        ("--half", {"action": "store_true"}),
        ("--compile", {"action": "store_true"}),
        ("--max-text-length", {"type": int, "default": 0}),
        ("--listen", {"type": str, "default": "127.0.0.1:8080"}),
        ("--workers", {"type": int, "default": 1}),
        ("--api-key", {"type": str, "default": None}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
|
|
|
|
|
class MsgPackRequest(HttpRequest):
    """Request type whose payload may be MessagePack or JSON."""

    async def data(
        self,
    ) -> Annotated[
        Any, ContentType("application/msgpack"), ContentType("application/json")
    ]:
        """Decode the request body according to its content type.

        Returns:
            The object unpacked by ``ormsgpack`` for ``application/msgpack``
            requests, or the value of ``self.json`` for ``application/json``
            requests.

        Raises:
            HTTPException: 415 Unsupported Media Type for any other content
                type; the response carries an ``Accept`` header naming the
                two supported types.
        """
        if self.content_type == "application/msgpack":
            raw_body = await self.body
            return ormsgpack.unpackb(raw_body)

        if self.content_type == "application/json":
            parsed = await self.json
            return parsed

        # Neither supported media type matched: reject the request.
        supported = {"Accept": "application/msgpack, application/json"}
        raise HTTPException(HTTPStatus.UNSUPPORTED_MEDIA_TYPE, headers=supported)
|
|
|
|
|
async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
    """Yield the ``bytes`` items produced by ``inference(req, engine)``.

    Items that are not ``bytes`` instances are skipped.

    NOTE(review): ``inference`` is consumed with a plain ``for`` loop, so each
    step runs synchronously inside this coroutine; if a step blocks, the
    event loop presumably blocks with it -- confirm against the caller.
    """
    for piece in inference(req, engine):
        if not isinstance(piece, bytes):
            continue
        yield piece
|
|
|
|
|
async def buffer_to_async_generator(buffer):
    """Wrap ``buffer`` in an async generator that yields it exactly once."""
    yield buffer
|
|
|
|
|
def get_content_type(audio_format):
    """Map an audio format name to the MIME type used for the response.

    Args:
        audio_format: Format name; ``"wav"``, ``"flac"`` and ``"mp3"`` are
            recognised (exact, case-sensitive match).

    Returns:
        The matching MIME type string, or ``"application/octet-stream"`` for
        any other value.
    """
    known_types = (
        ("wav", "audio/wav"),
        ("flac", "audio/flac"),
        ("mp3", "audio/mpeg"),
    )
    # Compare with ``==`` rather than a dict lookup so that unhashable or
    # otherwise unusual inputs still fall through to the default.
    for name, mime in known_types:
        if audio_format == name:
            return mime
    return "application/octet-stream"
|
|