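"""HTTP API server for Fish Speech.

Builds a Kui ASGI app from the routes in `tools.server.views`, exposes an
OpenAPI schema, optionally enforces bearer-token auth, and serves the app
with Uvicorn.
"""
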
from threading import Lock

import pyrootutils
import uvicorn
from kui.asgi import (
    Depends,
    FactoryClass,
    HTTPException,
    HttpRoute,
    Kui,
    OpenAPI,
    Routes,
)
from kui.cors import CORSConfig
from kui.openapi.specification import Info
from kui.security import bearer_auth
from loguru import logger
from typing_extensions import Annotated
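
# setup_root must run before the `tools.*` imports below so the project root
# is on sys.path.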
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from tools.server.api_utils import MsgPackRequest, parse_args
from tools.server.exception_handler import ExceptionHandler
from tools.server.model_manager import ModelManager
from tools.server.views import routes


class API(ExceptionHandler):
    def __init__(self):
        self.args = parse_args()
        self.routes = routes
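
        # Optional bearer-token gate: when an API key is configured, every
        # request must present a matching token; otherwise requests pass
        # through unchanged.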
        def api_auth(endpoint):
            async def verify(token: Annotated[str, Depends(bearer_auth)]):
                if token != self.args.api_key:
                    raise HTTPException(401, None, "Invalid token")
                return await endpoint()

            async def passthrough():
                return await endpoint()

            if self.args.api_key is not None:
                return verify
            else:
                return passthrough

        self.openapi = OpenAPI(
            Info(
                {
                    "title": "Fish Speech API",
                    "version": "1.5.0",
                }
            ),
        ).routes

        # Initialize the app
        self.app = Kui(
            routes=self.routes + self.openapi[1:],  # Remove the default route
            exception_handlers={
                HTTPException: self.http_exception_handler,
                Exception: self.other_exception_handler,
            },
            factory_class=FactoryClass(http=MsgPackRequest),
            cors_config=CORSConfig(),
        )

        # Add the state variables; the lock serializes model access within
        # this worker process (see the threading notes below).
        self.app.state.lock = Lock()
        self.app.state.device = self.args.device
        self.app.state.max_text_length = self.args.max_text_length

        # Associate the app with the model manager
        self.app.on_startup(self.initialize_app)
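
    # Runs at app startup, so the models are constructed inside the process
    # that will actually serve requests.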
    async def initialize_app(self, app: Kui):
        # Make the ModelManager available to the views
        app.state.model_manager = ModelManager(
            mode=self.args.mode,
            device=self.args.device,
            half=self.args.half,
            compile=self.args.compile,
            asr_enabled=self.args.load_asr_model,
            llama_checkpoint_path=self.args.llama_checkpoint_path,
            decoder_checkpoint_path=self.args.decoder_checkpoint_path,
            decoder_config_name=self.args.decoder_config_name,
        )

        logger.info(f"Startup done, listening server at http://{self.args.listen}")


# Each worker process created by Uvicorn has its own memory space, so models
# and variables (such as `llama_queue` or `decoder_model`) are not shared
# across workers. Multi-threaded inference can also produce inconsistent
# outputs when several threads access the same buffers simultaneously, so
# prefer multiprocessing or an independent model instance per thread.

if __name__ == "__main__":
    api = API()
    host, port = api.args.listen.split(":")
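
    # Note: Uvicorn only honors `workers` when the app is passed as an import
    # string; with an app instance like this it serves from a single process.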
    uvicorn.run(
        api.app,
        host=host,
        port=int(port),
        workers=api.args.workers,
        log_level="info",
    )
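
# Example invocation (the file name and flag names are assumptions; the
# actual flags come from tools.server.api_utils.parse_args):
#   python api_server.py --listen 127.0.0.1:8080 --workers 1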