HibiAPI / hibiapi /app /application.py
DengFengLai's picture
DF.
0a1b571
raw
history blame
5.38 kB
import asyncio
import re
from contextlib import asynccontextmanager
from ipaddress import ip_address
from secrets import compare_digest
from typing import Annotated
import sentry_sdk
from fastapi import Depends, FastAPI, Request, Response
from fastapi.responses import RedirectResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from sentry_sdk.integrations.logging import LoggingIntegration
from hibiapi import __version__
from hibiapi.app.routes import router as ImplRouter
from hibiapi.utils.cache import cache
from hibiapi.utils.config import Config
from hibiapi.utils.exceptions import ClientSideException, RateLimitReachedException
from hibiapi.utils.log import logger
from hibiapi.utils.net import BaseNetClient
from hibiapi.utils.temp import TempFile
# Markdown description handed to FastAPI for the generated documentation:
# a fixed project blurb followed by the slogan read from the configuration.
DESCRIPTION = "".join(
    (
        """
**A program that implements easy-to-use APIs for a variety of commonly used sites**
- *Documents*:
- [Redoc](/docs) (Easier to read and more beautiful)
- [Swagger UI](/docs/test) (Integrated interactive testing function)
Project: [mixmoe/HibiAPI](https://github.com/mixmoe/HibiAPI)
""",
        Config["content"]["slogan"].as_str().strip(),
    )
).strip()
# Sentry is initialised in both cases; the configured DSN, PII flag and trace
# sample rate are only passed when reporting is enabled in the configuration.
if not Config["log"]["sentry"]["enabled"].as_bool():
    sentry_sdk.init()
else:
    sentry_sdk.init(
        dsn=Config["log"]["sentry"]["dsn"].as_str(),
        send_default_pii=Config["log"]["sentry"]["pii"].as_bool(),
        integrations=[LoggingIntegration(level=None, event_level=None)],
        traces_sample_rate=Config["log"]["sentry"]["sample"].get(float),
    )
class AuthorizationModel(BaseModel):
    """A single username/password pair accepted by HTTP Basic authorization."""

    # Login name compared against the name supplied by the client.
    username: str
    # Password compared directly against the one supplied by the client.
    password: str
# HTTP Basic scheme instance used as the credential-extracting dependency.
security = HTTPBasic()

# Switch read from the configuration: is HTTP Basic authorization active?
AUTHORIZATION_ENABLED = Config["authorization"]["enabled"].as_bool()
# Accepted credential pairs, loaded from the configuration as models.
AUTHORIZATION_ALLOWED = Config["authorization"]["allowed"].get(
    list[AuthorizationModel]
)
async def basic_authorization_depend(
    credentials: Annotated[HTTPBasicCredentials, Depends(security)],
):
    """Validate HTTP Basic credentials against the configured allow-list.

    Args:
        credentials: Username/password pair provided by the ``security``
            dependency.

    Returns:
        The ``(username, password)`` tuple of the accepted credentials.

    Raises:
        ClientSideException: With status code 401 and a ``WWW-Authenticate``
            header when no entry of ``AUTHORIZATION_ALLOWED`` matches.
    """
    # NOTE: We use `compare_digest` to avoid timing attacks.
    # Ref: https://fastapi.tiangolo.com/advanced/security/http-basic-auth/
    # `compare_digest` raises `TypeError` for `str` arguments that contain
    # non-ASCII characters, so both sides are compared as UTF-8 bytes.
    supplied_username = credentials.username.encode("utf-8")
    supplied_password = credentials.password.encode("utf-8")

    for allowed in AUTHORIZATION_ALLOWED:
        # Run both comparisons before combining the results so that a
        # username mismatch does not skip the password comparison.
        username_matches = compare_digest(
            supplied_username, allowed.username.encode("utf-8")
        )
        password_matches = compare_digest(
            supplied_password, allowed.password.encode("utf-8")
        )
        if username_matches and password_matches:
            return credentials.username, credentials.password

    raise ClientSideException(
        f"Invalid credentials for user {credentials.username!r}",
        status_code=401,
        headers={"WWW-Authenticate": "Basic"},
    )
# Rate-limit settings read once from the configuration at import time.
# Switch: is rate limiting active?
RATE_LIMIT_ENABLED = Config["limit"]["enabled"].as_bool()
# Timeout applied to a client's counter when it is first created.
RATE_LIMIT_INTERVAL = Config["limit"]["interval"].as_number()
# Highest request count a client may reach before requests are rejected.
RATE_LIMIT_MAX = Config["limit"]["max"].as_number()
async def rate_limit_depend(request: Request):
    """Count the request against the per-client rate limit.

    The counter key is derived from the client's address: a parsed IP address
    is normalised to its packed hexadecimal form, anything else falls back to
    the raw host string.

    Args:
        request: Incoming request; nothing is counted when it carries no
            client information.

    Raises:
        RateLimitReachedException: When the client's counter exceeds
            ``RATE_LIMIT_MAX``; carries a ``Retry-After`` header.
    """
    if not request.client:
        return

    host = request.client.host
    try:
        client_ip = ip_address(host)
    except ValueError:
        # `host` is not a valid IP literal; key on the raw value instead.
        limit_key = f"rate_limit:fallback-{host}"
    else:
        # `packed.hex()` is already a hexadecimal string. Formatting it with
        # the integer-only `:x` spec raised `ValueError`, which used to push
        # every client into the fallback branch above.
        limit_key = f"rate_limit:IPv{client_ip.version}-{client_ip.packed.hex()}"

    request_count = await cache.incr(limit_key)
    if request_count <= 1:
        # First request for this key: start its expiry countdown.
        await cache.expire(limit_key, timeout=RATE_LIMIT_INTERVAL)
    elif request_count > RATE_LIMIT_MAX:
        limit_remain: int = await cache.get_expire(limit_key)
        # HTTP header values are text, so the remaining expiry reported by
        # the cache is sent as a string.
        raise RateLimitReachedException(
            headers={"Retry-After": str(limit_remain)}
        )
    return
async def flush_sentry():
    """Close the current Sentry client when one exists, then flush the SDK."""
    if (active_client := sentry_sdk.Hub.current.client) is not None:
        active_client.close()
    sentry_sdk.flush()
    logger.debug("Sentry client has been closed")
async def cleanup_clients():
    """Close every client in ``BaseNetClient.clients`` that is still open.

    Exceptions raised while closing are collected by ``asyncio.gather``
    instead of being propagated, so one failing client does not stop the rest.
    """
    opened_clients = []
    for candidate in BaseNetClient.clients:
        if not candidate.is_closed:
            opened_clients.append(candidate)

    if opened_clients:
        pending_closes = (client.aclose() for client in opened_clients)
        await asyncio.gather(*pending_closes, return_exceptions=True)

    logger.debug(f"Cleaned <r>{len(opened_clients)}</r> unclosed HTTP clients")
@asynccontextmanager
async def fastapi_lifespan(app: FastAPI):
    """Lifespan context: nothing on startup, resource cleanup on shutdown."""
    yield
    # After the application stops, close HTTP clients and flush Sentry
    # concurrently.
    shutdown_tasks = (cleanup_clients(), flush_sentry())
    await asyncio.gather(*shutdown_tasks)
# The application object; the two documentation UIs live at the URLs that
# DESCRIPTION links to.
app = FastAPI(
    title="HibiAPI",
    version=__version__,
    description=DESCRIPTION,
    redoc_url="/docs",
    docs_url="/docs/test",
    lifespan=fastapi_lifespan,
)

# Every implemented endpoint is served below /api. Authorization and rate
# limiting are attached only when enabled, with authorization listed first.
app.include_router(
    ImplRouter,
    prefix="/api",
    dependencies=[
        Depends(dependency)
        for enabled, dependency in (
            (AUTHORIZATION_ENABLED, basic_authorization_depend),
            (RATE_LIMIT_ENABLED, rate_limit_depend),
        )
        if enabled
    ],
)

# Files under the temporary-file directory are served statically at /temp.
app.mount("/temp", StaticFiles(directory=TempFile.path, check_dir=False))
@app.get("/", include_in_schema=False)
async def redirect():
    """Answer requests for the site root with a 302 pointing at ``/docs``."""
    location_header = {"Location": "/docs"}
    return Response(headers=location_header, status_code=302)
@app.get("/robots.txt", include_in_schema=False)
async def robots():
    """Serve the robots.txt body taken from the configuration.

    Returns:
        A 200 response whose body is the configured robots text with
        surrounding whitespace stripped, declared as ``text/plain``.
    """
    content = Config["content"]["robots"].as_str().strip()
    # Without an explicit media type the response carries no Content-Type;
    # RFC 9309 requires robots.txt to be served as text/plain.
    return Response(content, status_code=200, media_type="text/plain")
@app.middleware("http")
async def redirect_workaround_middleware(request: Request, call_next):
    """Temporary redirection workaround for #12

    Requests for ``/<service>/<endpoint>`` of the listed services are answered
    with a permanent redirect to the same path below ``/api``; every other
    request is passed on to the next handler unchanged.
    """
    matched = re.match(
        r"^/(qrcode|pixiv|netease|bilibili)/(\w*)$", request.url.path
    )
    if matched is None:
        return await call_next(request)

    service, path = matched.groups()
    # Only the path is swapped; the rest of the original URL is kept.
    redirect_url = request.url.replace(path=f"/api/{service}/{path}")
    return RedirectResponse(redirect_url, status_code=301)