# Audit logging: ASGI middleware that records structured audit entries for
# authenticated, state-changing HTTP requests.
from contextlib import asynccontextmanager | |
from dataclasses import asdict, dataclass | |
from enum import Enum | |
import re | |
from typing import ( | |
TYPE_CHECKING, | |
Any, | |
AsyncGenerator, | |
Dict, | |
MutableMapping, | |
Optional, | |
cast, | |
) | |
import uuid | |
from asgiref.typing import ( | |
ASGI3Application, | |
ASGIReceiveCallable, | |
ASGIReceiveEvent, | |
ASGISendCallable, | |
ASGISendEvent, | |
Scope as ASGIScope, | |
) | |
from loguru import logger | |
from starlette.requests import Request | |
from open_webui.env import AUDIT_LOG_LEVEL, MAX_BODY_LOG_SIZE | |
from open_webui.utils.auth import get_current_user, get_http_authorization_cred | |
from open_webui.models.users import UserModel | |
if TYPE_CHECKING: | |
from loguru import Logger | |
@dataclass
class AuditLogEntry:
    """
    One structured audit record describing a single HTTP request.

    The class must be a dataclass: instances are built with keyword arguments
    and serialised with ``dataclasses.asdict`` before being handed to the logger.
    Fields are grouped by the audit level at which they are meaningful.
    """

    # `Metadata` audit level properties
    id: str  # unique identifier of this audit entry
    user: dict[str, Any]  # selected fields of the user who issued the request
    audit_level: str  # audit level in effect when the entry was produced
    verb: str  # HTTP method of the request
    request_uri: str  # full request URL
    user_agent: Optional[str] = None
    source_ip: Optional[str] = None
    # `Request` audit level properties
    request_object: Any = None  # captured request body
    # `Request Response` level
    response_object: Any = None  # captured response body
    response_status_code: Optional[int] = None
class AuditLevel(str, Enum):
    """
    Verbosity of audit logging.

    Members are also plain strings, so they compare equal to their textual
    value (e.g. ``AuditLevel.NONE == "NONE"``).
    """

    # Auditing disabled.
    NONE = "NONE"
    # Request metadata only.
    METADATA = "METADATA"
    # Metadata plus the request body.
    REQUEST = "REQUEST"
    # Metadata plus both the request and the response.
    REQUEST_RESPONSE = "REQUEST_RESPONSE"
class AuditLogger:
    """
    Thin wrapper around a Loguru logger dedicated to audit records.

    The wrapped logger is bound with ``auditable=True`` so that every record
    written through this class carries that marker in its bound context.

    Parameters:
    logger (Logger): An instance of Loguru's logger.
    """

    def __init__(self, logger: "Logger"):
        # Keep the bound copy; the caller's logger is left untouched.
        self.logger = logger.bind(auditable=True)

    def write(
        self,
        audit_entry: AuditLogEntry,
        *,
        log_level: str = "INFO",
        extra: Optional[dict] = None,
    ):
        """
        Emit one audit entry.

        The entry is flattened to a dict and passed to the logger as keyword
        arguments alongside an empty message. A truthy ``extra`` mapping is
        attached under the ``"extra"`` key; a falsy one is omitted entirely.
        """
        additions = {"extra": extra} if extra else {}
        payload = {**asdict(audit_entry), **additions}
        self.logger.log(log_level, "", **payload)
class AuditContext:
    """
    Per-request scratch space for audit data.

    Request and response payloads are accumulated chunk by chunk, each capped
    at ``max_body_size`` bytes so memory use stays bounded.

    Attributes:
    request_body (bytearray): Accumulated request payload.
    response_body (bytearray): Accumulated response payload.
    max_body_size (int): Maximum number of bytes to capture per body.
    metadata (Dict[str, Any]): Additional audit metadata collected while the request is processed.
    """

    def __init__(self, max_body_size: int = MAX_BODY_LOG_SIZE):
        self.max_body_size = max_body_size
        self.metadata: Dict[str, Any] = {}
        self.request_body = bytearray()
        self.response_body = bytearray()

    def _append_capped(self, buffer: bytearray, chunk: bytes) -> None:
        """Append as much of ``chunk`` as still fits under the size cap."""
        room = self.max_body_size - len(buffer)
        if room > 0:
            buffer.extend(chunk[:room])

    def add_request_chunk(self, chunk: bytes):
        """Record a piece of the request body, truncating at the cap."""
        self._append_capped(self.request_body, chunk)

    def add_response_chunk(self, chunk: bytes):
        """Record a piece of the response body, truncating at the cap."""
        self._append_capped(self.response_body, chunk)
class AuditLoggingMiddleware:
    """
    ASGI middleware that intercepts HTTP requests and responses to perform audit logging.

    Depending on the audit level it captures request/response bodies, and it
    always records the HTTP method, URL, client address, user agent and user
    information. One structured audit entry is written at the end of the
    request cycle.
    """

    # Only state-changing HTTP methods are audited.
    AUDITED_METHODS = {"PUT", "PATCH", "DELETE", "POST"}

    def __init__(
        self,
        app: ASGI3Application,
        *,
        excluded_paths: Optional[list[str]] = None,
        max_body_size: int = MAX_BODY_LOG_SIZE,
        audit_level: AuditLevel = AuditLevel.NONE,
    ) -> None:
        """
        Parameters:
        app: The downstream ASGI application.
        excluded_paths: Resource names (regex alternatives) under ``/api/`` or ``/api/v1/`` that must not be audited.
        max_body_size: Maximum number of bytes captured per request/response body.
        audit_level: How much of each request/response is captured.
        """
        self.app = app
        self.audit_logger = AuditLogger(logger)
        self.excluded_paths = excluded_paths or []
        self.max_body_size = max_body_size
        self.audit_level = audit_level

    async def __call__(
        self,
        scope: ASGIScope,
        receive: ASGIReceiveCallable,
        send: ASGISendCallable,
    ) -> None:
        """
        ASGI entry point.

        Non-HTTP scopes and requests that should not be audited are passed
        straight through. Otherwise ``receive``/``send`` are wrapped so bodies
        can be captured, and an audit entry is written once the downstream
        application finishes (including when it raises).
        """
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        request = Request(scope=cast(MutableMapping, scope))

        if self._should_skip_auditing(request):
            return await self.app(scope, receive, send)

        async with self._audit_context(request) as context:

            async def send_wrapper(message: ASGISendEvent) -> None:
                # Response data is only captured at the most verbose level.
                if self.audit_level == AuditLevel.REQUEST_RESPONSE:
                    await self._capture_response(message, context)
                await send(message)

            async def receive_wrapper() -> ASGIReceiveEvent:
                message = await receive()
                if self.audit_level in (
                    AuditLevel.REQUEST,
                    AuditLevel.REQUEST_RESPONSE,
                ):
                    await self._capture_request(message, context)
                return message

            await self.app(scope, receive_wrapper, send_wrapper)

    @asynccontextmanager
    async def _audit_context(
        self, request: Request
    ) -> AsyncGenerator[AuditContext, None]:
        """
        Async context manager that ensures an audit log entry is recorded
        after the request is processed, whether or not processing raised.
        """
        # Honour the size limit configured on the middleware.
        context = AuditContext(max_body_size=self.max_body_size)
        try:
            yield context
        finally:
            await self._log_audit_entry(request, context)

    async def _get_authenticated_user(self, request: Request) -> UserModel:
        """
        Resolve the user from the request's ``Authorization`` header.

        Raises:
        ValueError: If the header is missing or empty.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            # Explicit check instead of `assert`, which is stripped under -O.
            raise ValueError("Missing Authorization header")
        # NOTE(review): get_current_user is called synchronously here, as in
        # the original code — confirm it is not a coroutine function.
        user = get_current_user(request, None, get_http_authorization_cred(auth_header))
        return user

    def _should_skip_auditing(self, request: Request) -> bool:
        """
        Decide whether a request is exempt from auditing.

        A request is skipped when its method is not in ``AUDITED_METHODS``,
        when ``AUDIT_LOG_LEVEL`` is ``"NONE"``, when it carries no
        ``authorization`` header, or when its path matches an excluded path.
        """
        if (
            request.method not in self.AUDITED_METHODS
            or AUDIT_LOG_LEVEL == "NONE"
            or not request.headers.get("authorization")
        ):
            return True

        # Drop empty entries: an empty alternation would make the pattern
        # match every /api/<resource> path and silently disable auditing.
        excluded = [path for path in self.excluded_paths if path]
        if not excluded:
            return False

        # match either /api/<resource>/...(for the endpoint /api/chat case) or /api/v1/<resource>/...
        # Entries are used as regex alternatives, unescaped, as before.
        pattern = re.compile(r"^/api(?:/v1)?/(" + "|".join(excluded) + r")\b")
        return pattern.match(request.url.path) is not None

    async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext):
        """Append the body of an ``http.request`` event to the audit context."""
        if message["type"] == "http.request":
            body = message.get("body", b"")
            context.add_request_chunk(body)

    async def _capture_response(self, message: ASGISendEvent, context: AuditContext):
        """Record the response status code and append response body chunks."""
        if message["type"] == "http.response.start":
            context.metadata["response_status_code"] = message["status"]
        elif message["type"] == "http.response.body":
            body = message.get("body", b"")
            context.add_response_chunk(body)

    async def _log_audit_entry(self, request: Request, context: AuditContext):
        """
        Build and write the audit entry for a finished request.

        Best effort: any failure is logged as an error and swallowed so that
        audit logging never breaks the request itself.
        """
        try:
            user = await self._get_authenticated_user(request)
            entry = AuditLogEntry(
                id=str(uuid.uuid4()),
                user=user.model_dump(include={"id", "name", "email", "role"}),
                audit_level=self.audit_level.value,
                verb=request.method,
                request_uri=str(request.url),
                response_status_code=context.metadata.get("response_status_code", None),
                source_ip=request.client.host if request.client else None,
                user_agent=request.headers.get("user-agent"),
                # Bodies may be truncated mid-character; replace undecodable bytes.
                request_object=context.request_body.decode("utf-8", errors="replace"),
                response_object=context.response_body.decode("utf-8", errors="replace"),
            )
            self.audit_logger.write(entry)
        except Exception as e:
            logger.error(f"Failed to log audit entry: {str(e)}")