Spaces:
Runtime error
Runtime error
import asyncio | |
import re | |
from contextlib import asynccontextmanager | |
from ipaddress import ip_address | |
from secrets import compare_digest | |
from typing import Annotated | |
import sentry_sdk | |
from fastapi import Depends, FastAPI, Request, Response | |
from fastapi.responses import RedirectResponse | |
from fastapi.security import HTTPBasic, HTTPBasicCredentials | |
from fastapi.staticfiles import StaticFiles | |
from pydantic import BaseModel | |
from sentry_sdk.integrations.logging import LoggingIntegration | |
from hibiapi import __version__ | |
from hibiapi.app.routes import router as ImplRouter | |
from hibiapi.utils.cache import cache | |
from hibiapi.utils.config import Config | |
from hibiapi.utils.exceptions import ClientSideException, RateLimitReachedException | |
from hibiapi.utils.log import logger | |
from hibiapi.utils.net import BaseNetClient | |
from hibiapi.utils.temp import TempFile | |
DESCRIPTION = ( | |
""" | |
**A program that implements easy-to-use APIs for a variety of commonly used sites** | |
- *Documents*: | |
- [Redoc](/docs) (Easier to read and more beautiful) | |
- [Swagger UI](/docs/test) (Integrated interactive testing function) | |
Project: [mixmoe/HibiAPI](https://github.com/mixmoe/HibiAPI) | |
""" | |
+ Config["content"]["slogan"].as_str().strip() | |
).strip() | |
if Config["log"]["sentry"]["enabled"].as_bool(): | |
sentry_sdk.init( | |
dsn=Config["log"]["sentry"]["dsn"].as_str(), | |
send_default_pii=Config["log"]["sentry"]["pii"].as_bool(), | |
integrations=[LoggingIntegration(level=None, event_level=None)], | |
traces_sample_rate=Config["log"]["sentry"]["sample"].get(float), | |
) | |
else: | |
sentry_sdk.init() | |
class AuthorizationModel(BaseModel): | |
username: str | |
password: str | |
AUTHORIZATION_ENABLED = Config["authorization"]["enabled"].as_bool() | |
AUTHORIZATION_ALLOWED = Config["authorization"]["allowed"].get(list[AuthorizationModel]) | |
security = HTTPBasic() | |
async def basic_authorization_depend( | |
credentials: Annotated[HTTPBasicCredentials, Depends(security)], | |
): | |
# NOTE: We use `compare_digest` to avoid timing attacks. | |
# Ref: https://fastapi.tiangolo.com/advanced/security/http-basic-auth/ | |
for allowed in AUTHORIZATION_ALLOWED: | |
if compare_digest(credentials.username, allowed.username) and compare_digest( | |
credentials.password, allowed.password | |
): | |
return credentials.username, credentials.password | |
raise ClientSideException( | |
f"Invalid credentials for user {credentials.username!r}", | |
status_code=401, | |
headers={"WWW-Authenticate": "Basic"}, | |
) | |
RATE_LIMIT_ENABLED = Config["limit"]["enabled"].as_bool() | |
RATE_LIMIT_MAX = Config["limit"]["max"].as_number() | |
RATE_LIMIT_INTERVAL = Config["limit"]["interval"].as_number() | |
async def rate_limit_depend(request: Request): | |
if not request.client: | |
return | |
try: | |
client_ip = ip_address(request.client.host) | |
client_ip_hex = client_ip.packed.hex() | |
limit_key = f"rate_limit:IPv{client_ip.version}-{client_ip_hex:x}" | |
except ValueError: | |
limit_key = f"rate_limit:fallback-{request.client.host}" | |
request_count = await cache.incr(limit_key) | |
if request_count <= 1: | |
await cache.expire(limit_key, timeout=RATE_LIMIT_INTERVAL) | |
elif request_count > RATE_LIMIT_MAX: | |
limit_remain: int = await cache.get_expire(limit_key) | |
raise RateLimitReachedException(headers={"Retry-After": limit_remain}) | |
return | |
async def flush_sentry(): | |
client = sentry_sdk.Hub.current.client | |
if client is not None: | |
client.close() | |
sentry_sdk.flush() | |
logger.debug("Sentry client has been closed") | |
async def cleanup_clients(): | |
opened_clients = [ | |
client for client in BaseNetClient.clients if not client.is_closed | |
] | |
if opened_clients: | |
await asyncio.gather( | |
*map(lambda client: client.aclose(), opened_clients), | |
return_exceptions=True, | |
) | |
logger.debug(f"Cleaned <r>{len(opened_clients)}</r> unclosed HTTP clients") | |
async def fastapi_lifespan(app: FastAPI): | |
yield | |
await asyncio.gather(cleanup_clients(), flush_sentry()) | |
app = FastAPI( | |
title="HibiAPI", | |
version=__version__, | |
description=DESCRIPTION, | |
docs_url="/docs/test", | |
redoc_url="/docs", | |
lifespan=fastapi_lifespan, | |
) | |
app.include_router( | |
ImplRouter, | |
prefix="/api", | |
dependencies=( | |
([Depends(basic_authorization_depend)] if AUTHORIZATION_ENABLED else []) | |
+ ([Depends(rate_limit_depend)] if RATE_LIMIT_ENABLED else []) | |
), | |
) | |
app.mount("/temp", StaticFiles(directory=TempFile.path, check_dir=False)) | |
async def redirect(): | |
return Response(status_code=302, headers={"Location": "/docs"}) | |
async def robots(): | |
content = Config["content"]["robots"].as_str().strip() | |
return Response(content, status_code=200) | |
async def redirect_workaround_middleware(request: Request, call_next): | |
"""Temporary redirection workaround for #12""" | |
if matched := re.match( | |
r"^/(qrcode|pixiv|netease|bilibili)/(\w*)$", request.url.path | |
): | |
service, path = matched.groups() | |
redirect_url = request.url.replace(path=f"/api/{service}/{path}") | |
return RedirectResponse(redirect_url, status_code=301) | |
return await call_next(request) | |