Spaces:
Runtime error
Runtime error
import os | |
from functools import wraps | |
from typing import Any, Awaitable, Callable | |
import uvicorn | |
from fastapi import FastAPI, Request | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import JSONResponse | |
from inference.core import logger | |
from inference.enterprise.stream_management.api.entities import ( | |
CommandResponse, | |
InferencePipelineStatusResponse, | |
ListPipelinesResponse, | |
PipelineInitialisationRequest, | |
) | |
from inference.enterprise.stream_management.api.errors import ( | |
ConnectivityError, | |
ProcessesManagerAuthorisationError, | |
ProcessesManagerClientError, | |
ProcessesManagerInvalidPayload, | |
ProcessesManagerNotFoundError, | |
) | |
from inference.enterprise.stream_management.api.stream_manager_client import ( | |
StreamManagerClient, | |
) | |
from inference.enterprise.stream_management.manager.entities import ( | |
STATUS_KEY, | |
OperationStatus, | |
) | |
# Host and port that uvicorn binds to when this module is run as a script.
API_HOST = os.environ.get("STREAM_MANAGEMENT_API_HOST", "127.0.0.1")
API_PORT = int(os.environ.get("STREAM_MANAGEMENT_API_PORT", "8080"))

# Optional timeout forwarded to the stream manager client; stays None when the
# environment variable is not set, otherwise it is parsed as a float.
_operations_timeout_env = os.environ.get("STREAM_MANAGER_OPERATIONS_TIMEOUT")
OPERATIONS_TIMEOUT = (
    None if _operations_timeout_env is None else float(_operations_timeout_env)
)

# Single module-level client shared by all handlers in this file.
STREAM_MANAGER_CLIENT = StreamManagerClient.init(
    host=os.environ.get("STREAM_MANAGER_HOST", "127.0.0.1"),
    port=int(os.environ.get("STREAM_MANAGER_PORT", "7070")),
    operations_timeout=OPERATIONS_TIMEOUT,
)

# NOTE(review): no route registrations (e.g. @app.get / @app.post) are visible
# in this view of the file - confirm the handlers are attached to `app`.
app = FastAPI()

# NOTE(review): wildcard origins are combined with allow_credentials=True;
# confirm this permissive CORS policy is intended for the deployment.
_wildcard = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=list(_wildcard),
    allow_credentials=True,
    allow_methods=list(_wildcard),
    allow_headers=list(_wildcard),
)
def _failure_response(
    status_code: int, message: str, log_message: str
) -> "JSONResponse":
    """Build the failure JSON response and log the exception being handled.

    Must be called from inside an ``except`` block so that
    ``logger.exception`` can attach the active traceback.
    """
    response = JSONResponse(
        status_code=status_code,
        content={STATUS_KEY: OperationStatus.FAILURE, "message": message},
    )
    logger.exception(log_message)
    return response


def with_route_exceptions(
    route: Callable[..., Awaitable[Any]],
) -> Callable[..., Awaitable[Any]]:
    """Wrap an async route so that raised errors become JSON failure responses.

    On success the wrapped route's own return value is passed through
    unchanged. On failure the exception is logged and translated as follows:

    * ``ProcessesManagerInvalidPayload``      -> 400
    * ``ProcessesManagerAuthorisationError``  -> 401
    * ``ProcessesManagerNotFoundError``       -> 404
    * ``ConnectivityError``                   -> 503
    * ``ProcessesManagerClientError``         -> 500
    * any other ``Exception``                 -> 500 with the fixed message
      ``"Internal error."`` (the exception text is not exposed)

    Args:
        route: Coroutine function to protect.

    Returns:
        A coroutine function that carries the name, docstring and signature
        metadata of ``route`` (via ``functools.wraps``).
    """

    # wraps() keeps the route's signature discoverable through __wrapped__;
    # without it introspection only sees (*args, **kwargs).
    @wraps(route)
    async def wrapped_route(*args, **kwargs):
        # Handler order is kept exactly as before.
        # NOTE(review): presumably the specific errors subclass
        # ProcessesManagerClientError - confirm before reordering.
        try:
            return await route(*args, **kwargs)
        except ProcessesManagerInvalidPayload as error:
            return _failure_response(
                400, str(error), "Processes Manager - invalid payload error"
            )
        except ProcessesManagerAuthorisationError as error:
            return _failure_response(
                401, str(error), "Processes Manager - authorisation error"
            )
        except ProcessesManagerNotFoundError as error:
            return _failure_response(
                404, str(error), "Processes Manager - not found error"
            )
        except ConnectivityError as error:
            return _failure_response(
                503, str(error), "Processes Manager connectivity error occurred"
            )
        except ProcessesManagerClientError as error:
            return _failure_response(
                500, str(error), "Processes Manager error occurred"
            )
        except Exception:
            # Top-level boundary: hide internal details from the caller.
            return _failure_response(
                500, "Internal error.", "Internal error in API"
            )

    return wrapped_route
async def list_pipelines(_: Request) -> ListPipelinesResponse:
    """Return the result of ``STREAM_MANAGER_CLIENT.list_pipelines()``.

    The request argument is accepted but not used.
    """
    listing = await STREAM_MANAGER_CLIENT.list_pipelines()
    return listing
async def get_status(pipeline_id: str) -> InferencePipelineStatusResponse:
    """Return the stream manager client's status report for ``pipeline_id``."""
    status_report = await STREAM_MANAGER_CLIENT.get_status(
        pipeline_id=pipeline_id,
    )
    return status_report
async def initialise(request: PipelineInitialisationRequest) -> CommandResponse:
    """Forward the initialisation request to the stream manager client.

    Returns whatever ``STREAM_MANAGER_CLIENT.initialise_pipeline`` yields.
    """
    outcome = await STREAM_MANAGER_CLIENT.initialise_pipeline(
        initialisation_request=request,
    )
    return outcome
async def pause(pipeline_id: str) -> CommandResponse:
    """Ask the stream manager client to pause the pipeline ``pipeline_id``."""
    outcome = await STREAM_MANAGER_CLIENT.pause_pipeline(
        pipeline_id=pipeline_id,
    )
    return outcome
async def resume(pipeline_id: str) -> CommandResponse:
    """Ask the stream manager client to resume the pipeline ``pipeline_id``."""
    outcome = await STREAM_MANAGER_CLIENT.resume_pipeline(
        pipeline_id=pipeline_id,
    )
    return outcome
async def terminate(pipeline_id: str) -> CommandResponse:
    """Ask the stream manager client to terminate the pipeline ``pipeline_id``."""
    outcome = await STREAM_MANAGER_CLIENT.terminate_pipeline(
        pipeline_id=pipeline_id,
    )
    return outcome
if __name__ == "__main__":
    # Script entry point: serve `app` with uvicorn on the configured address.
    uvicorn.run(
        app,
        port=API_PORT,
        host=API_HOST,
    )