Spaces:
Running
Running
import base64 | |
import random | |
import warnings | |
from collections.abc import Coroutine | |
from datetime import datetime, timedelta, timezone | |
from typing import TYPE_CHECKING, Annotated | |
from uuid import UUID | |
from cryptography.fernet import Fernet | |
from fastapi import Depends, HTTPException, Security, status | |
from fastapi.security import APIKeyHeader, APIKeyQuery, OAuth2PasswordBearer | |
from jose import JWTError, jwt | |
from loguru import logger | |
from sqlmodel.ext.asyncio.session import AsyncSession | |
from starlette.websockets import WebSocket | |
from langflow.services.database.models.api_key.crud import check_key | |
from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at | |
from langflow.services.database.models.user.model import User, UserRead | |
from langflow.services.deps import get_db_service, get_session, get_settings_service | |
from langflow.services.settings.service import SettingsService | |
if TYPE_CHECKING: | |
from langflow.services.database.models.api_key.model import ApiKey | |
# Bearer-token extractor for the login flow. With auto_error=False a missing
# token does not reject the request here; get_current_user falls back to API keys.
oauth2_login = OAuth2PasswordBearer(
    tokenUrl="api/v1/login",
    auto_error=False,
)

# One name is shared by the query-string and header API-key extractors.
API_KEY_NAME = "x-api-key"

# Like oauth2_login, both extractors are non-fatal when the key is absent;
# api_key_security decides whether a missing key is an error.
api_key_query = APIKeyQuery(
    name=API_KEY_NAME,
    scheme_name="API key query",
    auto_error=False,
)
api_key_header = APIKeyHeader(
    name=API_KEY_NAME,
    scheme_name="API key header",
    auto_error=False,
)

# Secrets shorter than this are expanded into a Fernet key by ensure_valid_key
# instead of being used directly.
MINIMUM_KEY_LENGTH = 32
# Source: https://github.com/mrtolkien/fastapi_simple_security/blob/master/fastapi_simple_security/security_api_key.py
async def api_key_security(
    query_param: Annotated[str | None, Security(api_key_query)],
    header_param: Annotated[str | None, Security(api_key_header)],
) -> UserRead | None:
    """Resolve the calling user from an API key, or from AUTO_LOGIN.

    When ``AUTO_LOGIN`` is enabled, the configured ``SUPERUSER`` is looked up
    and any supplied key is ignored. Otherwise the key from the query string
    takes precedence over the one from the header.

    Args:
        query_param: API key from the query string, or ``None`` when absent
            (the extractor is built with ``auto_error=False``).
        header_param: API key from the request header, or ``None`` when absent.

    Returns:
        The authenticated user as a ``UserRead``.

    Raises:
        HTTPException: 400 if AUTO_LOGIN is on but no SUPERUSER is configured;
            403 if no key was supplied or the lookup found nothing.
        ValueError: if the lookup produced something other than a ``User``.
    """
    settings_service = get_settings_service()
    result: ApiKey | User | None

    async with get_db_service().with_async_session() as db:
        if settings_service.auth_settings.AUTO_LOGIN:
            # AUTO_LOGIN: authenticate as the configured superuser.
            if not settings_service.auth_settings.SUPERUSER:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Missing first superuser credentials",
                )
            result = await get_user_by_username(db, settings_service.auth_settings.SUPERUSER)
        elif not query_param and not header_param:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="An API key must be passed as query or header",
            )
        elif query_param:
            # The query-string key wins when both are supplied.
            result = await check_key(db, query_param)
        else:
            result = await check_key(db, header_param)

        if not result:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Invalid or missing API key",
            )
        if isinstance(result, User):
            return UserRead.model_validate(result, from_attributes=True)

    # NOTE(review): reached only when the lookup yields a truthy non-User value;
    # presumably check_key returns a User — confirm against its definition.
    msg = "Invalid result type"
    raise ValueError(msg)
async def get_current_user(
    token: Annotated[str | None, Security(oauth2_login)],
    query_param: Annotated[str | None, Security(api_key_query)],
    header_param: Annotated[str | None, Security(api_key_header)],
    db: Annotated[AsyncSession, Depends(get_session)],
) -> User:
    """Authenticate the request with a bearer JWT, falling back to an API key.

    Args:
        token: Bearer token, or ``None`` when no Authorization header was sent
            (``oauth2_login`` uses ``auto_error=False``).
        query_param: API key from the query string, or ``None``.
        header_param: API key from the header, or ``None``.
        db: Session used for the JWT user lookup.

    Returns:
        The authenticated user.

    Raises:
        HTTPException: propagated from the JWT / API-key checks, or 403 if the
            API-key check yields no user.
    """
    if token:
        return await get_current_user_by_jwt(token, db)

    # NOTE(review): on this path the value is the UserRead built by
    # api_key_security, not a User ORM instance, despite the return annotation.
    user = await api_key_security(query_param, header_param)
    if user:
        return user

    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Invalid or missing API key",
    )
async def get_current_user_by_jwt(
    token: str,
    db: AsyncSession,
) -> User:
    """Validate a bearer JWT and return the active user named by its ``sub`` claim.

    Args:
        token: The encoded JWT. If a coroutine is passed instead, it is awaited
            to obtain the token.
        db: Session used to load the user.

    Returns:
        The ``User`` whose id is the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when the secret key is unset or empty, the token has
            expired, the token cannot be decoded/verified, the ``sub`` or
            ``type`` claim is missing, or the user is missing or inactive.
    """
    # ExpiredSignatureError subclasses JWTError, so it must be caught first.
    from jose.exceptions import ExpiredSignatureError

    settings_service = get_settings_service()

    if isinstance(token, Coroutine):
        token = await token

    secret_key = settings_service.auth_settings.SECRET_KEY.get_secret_value()
    # get_secret_value() presumably returns a str (pydantic SecretStr), so an
    # unset key would show up as "" rather than None; treat both as missing.
    if not secret_key:
        logger.error("Secret key is not set in settings.")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            # Careful not to leak sensitive information
            detail="Authentication failure: Verify authentication settings.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        # Silence warnings emitted by the JWT library while decoding.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            payload = jwt.decode(token, secret_key, algorithms=[settings_service.auth_settings.ALGORITHM])
    except ExpiredSignatureError as e:
        # jwt.decode verifies the ``exp`` claim itself and raises here, so expiry
        # must be reported from this handler rather than by a post-decode check.
        logger.info("Token expired for user")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has expired.",
            headers={"WWW-Authenticate": "Bearer"},
        ) from e
    except JWTError as e:
        logger.exception("JWT decoding error")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        ) from e

    user_id: UUID = payload.get("sub")  # type: ignore[assignment]
    token_type: str = payload.get("type")  # type: ignore[assignment]
    # NOTE(review): token_type is only checked for presence, so "refresh" and
    # "api_key" tokens are accepted here too — confirm whether that is intended.
    if user_id is None or token_type is None:
        logger.info(f"Invalid token payload. Token type: {token_type}")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token details.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    user = await get_user_by_id(db, user_id)
    if user is None or not user.is_active:
        logger.info("User not found or inactive.")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or is inactive.",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return user
async def get_current_user_for_websocket(
    websocket: WebSocket,
    db: Annotated[AsyncSession, Depends(get_session)],
    query_param: Annotated[str | None, Security(api_key_query)],
) -> User | UserRead | None:
    """Authenticate a WebSocket connection from its query string.

    A ``token`` query parameter (JWT) takes precedence. Otherwise the API-key
    query parameter is validated through ``api_key_security``.

    Args:
        websocket: The incoming connection; only its query parameters are read.
        db: Session used for the JWT user lookup.
        query_param: API key resolved by the ``api_key_query`` dependency. It is
            kept for signature compatibility; the key is read directly from
            ``websocket.query_params`` below.

    Returns:
        The authenticated user, or ``None`` when neither credential is present.
    """
    token = websocket.query_params.get("token")
    if token:
        return await get_current_user_by_jwt(token, db)

    api_key = websocket.query_params.get(API_KEY_NAME)
    if api_key:
        # The key comes from the query string, so it goes in the query slot.
        # No header value is available here; with a non-empty query key,
        # api_key_security never reads its header argument.
        return await api_key_security(api_key, None)
    return None
async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]) -> User:
    """Return the authenticated user, rejecting inactive accounts.

    Raises:
        HTTPException: 401 "Inactive user" when ``current_user.is_active`` is falsy.
    """
    if not current_user.is_active:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user")
    return current_user
async def get_current_active_superuser(current_user: Annotated[User, Depends(get_current_user)]) -> User:
    """Return the authenticated user only if it is both active and a superuser.

    Raises:
        HTTPException: 401 for an inactive user; 403 for an active non-superuser.
    """
    if not current_user.is_active:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user")
    if current_user.is_superuser:
        return current_user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="The user doesn't have enough privileges")
def verify_password(plain_password, hashed_password):
    """Check ``plain_password`` against ``hashed_password``.

    Delegates to the ``pwd_context`` configured on the auth settings.
    """
    context = get_settings_service().auth_settings.pwd_context
    return context.verify(plain_password, hashed_password)
def get_password_hash(password):
    """Hash ``password`` with the ``pwd_context`` configured on the auth settings."""
    context = get_settings_service().auth_settings.pwd_context
    return context.hash(password)
def create_token(data: dict, expires_delta: timedelta):
    """Encode ``data`` as a signed JWT that expires ``expires_delta`` from now.

    The caller's dict is not mutated; an ``exp`` claim is added to a copy.
    Signing uses the configured SECRET_KEY and ALGORITHM.
    """
    auth_settings = get_settings_service().auth_settings
    claims = {**data, "exp": datetime.now(timezone.utc) + expires_delta}
    return jwt.encode(
        claims,
        auth_settings.SECRET_KEY.get_secret_value(),
        algorithm=auth_settings.ALGORITHM,
    )
async def create_super_user(
    username: str,
    password: str,
    db: AsyncSession,
) -> User:
    """Return the user named ``username``, creating it as an active superuser if absent.

    An existing user is returned as-is; its password and flags are not updated.
    """
    existing = await get_user_by_username(db, username)
    if existing:
        return existing

    new_superuser = User(
        username=username,
        password=get_password_hash(password),
        is_superuser=True,
        is_active=True,
        last_login_at=None,
    )
    db.add(new_superuser)
    await db.commit()
    await db.refresh(new_superuser)
    return new_superuser
async def create_user_longterm_token(db: AsyncSession) -> tuple[UUID, dict]:
    """Issue a 365-day access token for the configured SUPERUSER and record the login.

    Returns:
        ``(superuser id, token payload)``. The payload has no refresh token.

    Raises:
        HTTPException: 400 if the superuser does not exist.
    """
    superuser_name = get_settings_service().auth_settings.SUPERUSER
    superuser = await get_user_by_username(db, superuser_name)
    if not superuser:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Super user hasn't been created")

    long_lived_token = create_token(
        data={"sub": str(superuser.id), "type": "access"},
        expires_delta=timedelta(days=365),
    )

    await update_user_last_login_at(superuser.id, db)

    payload = {"access_token": long_lived_token, "refresh_token": None, "token_type": "bearer"}
    return superuser.id, payload
def create_user_api_key(user_id: UUID) -> dict:
    """Mint a JWT of type ``api_key`` for ``user_id``, valid for 730 days."""
    claims = {"sub": str(user_id), "type": "api_key"}
    lifetime = timedelta(days=365 * 2)
    return {"api_key": create_token(data=claims, expires_delta=lifetime)}
def get_user_id_from_token(token: str) -> UUID:
    """Best-effort read of the ``sub`` claim as a UUID, WITHOUT verifying the token.

    The signature is not checked, so the result must not be trusted for
    authorization decisions.

    Returns:
        The ``sub`` claim parsed as a UUID, or the nil UUID (``UUID(int=0)``)
        when the token is malformed, has no ``sub``, or ``sub`` is not a valid
        UUID string.
    """
    try:
        user_id = jwt.get_unverified_claims(token)["sub"]
        return UUID(user_id)
    # TypeError / AttributeError: ``sub`` is present but not a string (e.g. null
    # or a number); UUID() rejects those with these errors, not ValueError.
    except (KeyError, JWTError, ValueError, TypeError, AttributeError):
        return UUID(int=0)
async def create_user_tokens(user_id: UUID, db: AsyncSession, *, update_last_login: bool = False) -> dict:
    """Mint an access/refresh token pair for ``user_id``.

    Lifetimes come from ACCESS_TOKEN_EXPIRE_SECONDS and
    REFRESH_TOKEN_EXPIRE_SECONDS. When ``update_last_login`` is set, the user's
    last-login timestamp is also updated.
    """
    auth_settings = get_settings_service().auth_settings
    subject = str(user_id)

    access_token = create_token(
        data={"sub": subject, "type": "access"},
        expires_delta=timedelta(seconds=auth_settings.ACCESS_TOKEN_EXPIRE_SECONDS),
    )
    refresh_token = create_token(
        data={"sub": subject, "type": "refresh"},
        expires_delta=timedelta(seconds=auth_settings.REFRESH_TOKEN_EXPIRE_SECONDS),
    )

    if update_last_login:
        await update_user_last_login_at(user_id, db)

    return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"}
async def create_refresh_token(refresh_token: str, db: AsyncSession):
    """Exchange a valid refresh token for a new access/refresh token pair.

    Args:
        refresh_token: An encoded JWT previously issued with ``type`` "refresh".
        db: Session used to look up the user.

    Returns:
        The token payload produced by ``create_user_tokens``.

    Raises:
        HTTPException: 401 "Invalid refresh token" if the token cannot be
            decoded, is not a refresh token, lacks ``sub``, or names a missing
            or inactive user.
    """
    settings_service = get_settings_service()

    try:
        # Ignore warning about datetime.utcnow
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            payload = jwt.decode(
                refresh_token,
                settings_service.auth_settings.SECRET_KEY.get_secret_value(),
                algorithms=[settings_service.auth_settings.ALGORITHM],
            )
    except JWTError as e:
        logger.exception("JWT decoding error")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token",
        ) from e

    user_id: UUID = payload.get("sub")  # type: ignore[assignment]
    token_type: str = payload.get("type")  # type: ignore[assignment]

    # Only tokens minted as refresh tokens (create_user_tokens uses "refresh")
    # may be exchanged. Access and api_key tokens, or tokens without a type,
    # must not be able to mint a fresh token pair.
    if user_id is None or token_type != "refresh":
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")

    user = await get_user_by_id(db, user_id)
    # Inactive users are rejected, matching get_current_user_by_jwt.
    if user is None or not user.is_active:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")

    return await create_user_tokens(user_id, db)
async def authenticate_user(username: str, password: str, db: AsyncSession) -> User | None:
    """Return the user if ``username``/``password`` match, otherwise ``None``.

    Raises:
        HTTPException: 400 "Waiting for approval" for an inactive user that has
            never logged in; 401 "Inactive user" for any other inactive user.
    """
    candidate = await get_user_by_username(db, username)
    if not candidate:
        return None

    if candidate.is_active:
        return candidate if verify_password(password, candidate.password) else None

    # NOTE(review): account status is reported before the password is checked,
    # so these responses reveal it without valid credentials — confirm intent.
    if candidate.last_login_at:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user")
    # Inactive and never logged in: treated as pending approval.
    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Waiting for approval")
def add_padding(s):
    """Pad ``s`` with ``=`` to a multiple of 4 characters for base64 decoding.

    Input whose length is already a multiple of 4 (including ``""``) is
    returned unchanged.
    """
    # -len(s) % 4 is 0 for aligned input. The form 4 - len(s) % 4 yields 4 in
    # that case and appends a spurious full group of "=".
    return s + "=" * (-len(s) % 4)
def ensure_valid_key(s: str) -> bytes:
    """Turn the secret ``s`` into urlsafe-base64 key material for Fernet.

    Secrets of at least MINIMUM_KEY_LENGTH characters are treated as base64 key
    material and only padded. Shorter secrets are used as a seed to derive 32
    pseudo-random bytes deterministically, so the same secret always yields the
    same key.
    """
    if len(s) < MINIMUM_KEY_LENGTH:
        # Use a private Random instance. Seeding the module-level generator
        # (random.seed) would reset the global random state for every other user
        # of ``random`` in the process. random.Random(s) seeds exactly like
        # random.seed(s), so the derived key, and anything already encrypted
        # with it, is unchanged.
        # NOTE(review): Mersenne Twister is not a KDF; a real KDF (e.g. HKDF)
        # would be stronger but would change derived keys and break existing
        # ciphertexts.
        rng = random.Random(s)
        return base64.urlsafe_b64encode(bytes(rng.getrandbits(8) for _ in range(32)))
    return add_padding(s).encode()
def get_fernet(settings_service: SettingsService):
    """Return a ``Fernet`` cipher keyed from the configured SECRET_KEY.

    The raw secret is normalised into Fernet key material by ``ensure_valid_key``.
    """
    raw_secret = settings_service.auth_settings.SECRET_KEY.get_secret_value()
    return Fernet(ensure_valid_key(raw_secret))
def encrypt_api_key(api_key: str, settings_service: SettingsService):
    """Encrypt ``api_key`` with the settings-derived Fernet key.

    This is reversible encryption (not hashing); ``decrypt_api_key`` undoes it.
    Returns the Fernet token as text.
    """
    token_bytes = get_fernet(settings_service).encrypt(api_key.encode())
    return token_bytes.decode()
def decrypt_api_key(encrypted_api_key: str, settings_service: SettingsService):
    """Decrypt a token produced by ``encrypt_api_key``.

    Args:
        encrypted_api_key: The encrypted key as text. Non-``str`` values are not
            decrypted.
        settings_service: Supplies the SECRET_KEY the Fernet key is derived from.

    Returns:
        The plaintext key, or ``""`` when ``encrypted_api_key`` is not a ``str``.

    Raises:
        Exception: whatever decryption or UTF-8 decoding raised (e.g. the
            cryptography library's ``InvalidToken`` for a wrong key or a
            tampered token).
    """
    fernet = get_fernet(settings_service)
    if not isinstance(encrypted_api_key, str):
        return ""
    try:
        return fernet.decrypt(encrypted_api_key.encode()).decode()
    except Exception:
        # Fernet.decrypt accepts str as well as bytes (per the cryptography
        # docs), so retrying with the raw str repeats the same decode and cannot
        # recover. Log and surface the original error instead of masking it
        # with a second failure.
        logger.debug("Failed to decrypt API key")
        raise