Tai Truong
fix readme
d202ada
import base64
import random
import warnings
from collections.abc import Coroutine
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Annotated
from uuid import UUID
from cryptography.fernet import Fernet
from fastapi import Depends, HTTPException, Security, status
from fastapi.security import APIKeyHeader, APIKeyQuery, OAuth2PasswordBearer
from jose import JWTError, jwt
from loguru import logger
from sqlmodel.ext.asyncio.session import AsyncSession
from starlette.websockets import WebSocket
from langflow.services.database.models.api_key.crud import check_key
from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at
from langflow.services.database.models.user.model import User, UserRead
from langflow.services.deps import get_db_service, get_session, get_settings_service
from langflow.services.settings.service import SettingsService
if TYPE_CHECKING:
from langflow.services.database.models.api_key.model import ApiKey
# Bearer-token scheme for JWT logins. auto_error=False is set so that (per FastAPI's
# security utilities) a missing token is reported to the dependency as a falsy value
# instead of aborting the request, which lets get_current_user fall back to API keys.
oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False)

# Name shared by the query-parameter and header variants of the API key.
API_KEY_NAME = "x-api-key"

# API key supplied as a query parameter (?x-api-key=...).
api_key_query = APIKeyQuery(name=API_KEY_NAME, scheme_name="API key query", auto_error=False)
# API key supplied as an HTTP header (x-api-key: ...).
api_key_header = APIKeyHeader(name=API_KEY_NAME, scheme_name="API key header", auto_error=False)

# Secrets shorter than this many characters are not used directly as a Fernet key;
# ensure_valid_key expands them into a derived 32-byte key instead.
MINIMUM_KEY_LENGTH = 32
# Source: https://github.com/mrtolkien/fastapi_simple_security/blob/master/fastapi_simple_security/security_api_key.py
async def api_key_security(
    query_param: Annotated[str, Security(api_key_query)],
    header_param: Annotated[str, Security(api_key_header)],
) -> UserRead | None:
    """Resolve the calling user from an API key, or from the configured superuser on auto-login.

    Args:
        query_param: API key taken from the ``x-api-key`` query parameter, if any.
        header_param: API key taken from the ``x-api-key`` header, if any.

    Returns:
        A ``UserRead`` built from the matched ``User``.

    Raises:
        HTTPException: 400 when auto-login is enabled but no superuser name is configured;
            403 when no key was supplied or the lookup produced nothing.
        ValueError: When the lookup produced a truthy value that is not a ``User``.
    """
    settings_service = get_settings_service()
    result: ApiKey | User | None

    async with get_db_service().with_async_session() as session:
        auth_settings = settings_service.auth_settings

        if auth_settings.AUTO_LOGIN:
            # Auto-login skips key validation entirely and acts as the configured superuser.
            if not auth_settings.SUPERUSER:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Missing first superuser credentials",
                )
            result = await get_user_by_username(session, auth_settings.SUPERUSER)
        else:
            # The query parameter takes precedence over the header when both are present.
            supplied_key = query_param or header_param
            if not supplied_key:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="An API key must be passed as query or header",
                )
            result = await check_key(session, supplied_key)

        if not result:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Invalid or missing API key",
            )
        if isinstance(result, User):
            return UserRead.model_validate(result, from_attributes=True)

    # Reached only when the lookup returned something truthy that is not a User.
    msg = "Invalid result type"
    raise ValueError(msg)
async def get_current_user(
    token: Annotated[str, Security(oauth2_login)],
    query_param: Annotated[str, Security(api_key_query)],
    header_param: Annotated[str, Security(api_key_header)],
    db: Annotated[AsyncSession, Depends(get_session)],
) -> User:
    """Authenticate the request, preferring a bearer JWT and falling back to an API key.

    Args:
        token: Bearer token from the OAuth2 scheme; falsy when absent.
        query_param: API key from the query string, if any.
        header_param: API key from the request header, if any.
        db: Database session used for the JWT user lookup.

    Returns:
        The authenticated user.

    Raises:
        HTTPException: 403 when neither mechanism yields a user; errors raised by
            the JWT or API-key helpers propagate unchanged.
    """
    if not token:
        # No JWT supplied: try the API key (or auto-login) path instead.
        api_key_user = await api_key_security(query_param, header_param)
        if not api_key_user:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Invalid or missing API key",
            )
        return api_key_user

    return await get_current_user_by_jwt(token, db)
async def get_current_user_by_jwt(
    token: str,
    db: AsyncSession,
) -> User:
    """Decode a JWT and load the active user it identifies.

    Args:
        token: Encoded JWT. A coroutine is also tolerated and awaited first.
        db: Database session used to load the user.

    Returns:
        The active user named by the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when the secret key is unset, the token cannot be decoded,
            it has expired, it lacks ``sub`` or ``type``, or the user is missing or inactive.
    """

    def _unauthorized(detail: str) -> HTTPException:
        # Every failure in this function is reported as the same 401 shape.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    settings_service = get_settings_service()

    if isinstance(token, Coroutine):
        token = await token

    secret_key = settings_service.auth_settings.SECRET_KEY.get_secret_value()
    if secret_key is None:
        logger.error("Secret key is not set in settings.")
        # Careful not to leak sensitive information
        raise _unauthorized("Authentication failure: Verify authentication settings.")

    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            claims = jwt.decode(token, secret_key, algorithms=[settings_service.auth_settings.ALGORITHM])
        user_id: UUID = claims.get("sub")  # type: ignore[assignment]
        token_type: str = claims.get("type")  # type: ignore[assignment]

        expires = claims.get("exp", None)
        if expires:
            # Explicit expiry check against the current UTC time.
            expiry = datetime.fromtimestamp(expires, timezone.utc)
            if expiry < datetime.now(timezone.utc):
                logger.info("Token expired for user")
                raise _unauthorized("Token has expired.")

        if user_id is None or token_type is None:
            logger.info(f"Invalid token payload. Token type: {token_type}")
            raise _unauthorized("Invalid token details.")
    except JWTError as e:
        logger.exception("JWT decoding error")
        raise _unauthorized("Could not validate credentials") from e

    account = await get_user_by_id(db, user_id)
    if account is None or not account.is_active:
        logger.info("User not found or inactive.")
        raise _unauthorized("User not found or is inactive.")
    return account
async def get_current_user_for_websocket(
    websocket: WebSocket,
    db: Annotated[AsyncSession, Depends(get_session)],
    query_param: Annotated[str, Security(api_key_query)],
) -> User | None:
    """Authenticate a WebSocket connection from its query string.

    A ``token`` query parameter is treated as a JWT and takes precedence; otherwise an
    ``x-api-key`` query parameter is validated as an API key.

    Args:
        websocket: The incoming WebSocket connection.
        db: Database session used for the JWT user lookup.
        query_param: API key resolved by the query-parameter security scheme.

    Returns:
        The authenticated user, or ``None`` when neither credential is present.
    """
    params = websocket.query_params
    token = params.get("token")
    api_key = params.get("x-api-key")

    if token:
        return await get_current_user_by_jwt(token, db)

    # NOTE(review): `query_param` is forwarded in api_key_security's header position.
    return await api_key_security(api_key, query_param) if api_key else None
async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
    """Return the authenticated user, rejecting inactive accounts.

    Raises:
        HTTPException: 401 when the user is not active.
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user")
async def get_current_active_superuser(current_user: Annotated[User, Depends(get_current_user)]) -> User:
    """Return the authenticated user only if it is an active superuser.

    Raises:
        HTTPException: 401 when the user is not active; 403 when the user is active
            but is not a superuser.
    """
    if current_user.is_active:
        if current_user.is_superuser:
            return current_user
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="The user doesn't have enough privileges")
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user")
def verify_password(plain_password, hashed_password):
    """Check a plaintext password against its stored hash using the configured password context."""
    pwd_context = get_settings_service().auth_settings.pwd_context
    return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password):
    """Hash a plaintext password using the configured password context."""
    pwd_context = get_settings_service().auth_settings.pwd_context
    return pwd_context.hash(password)
def create_token(data: dict, expires_delta: timedelta):
    """Encode ``data`` as a signed JWT that expires ``expires_delta`` from now.

    Args:
        data: Claims to embed. The dict is not mutated; an ``exp`` claim is set on a copy.
        expires_delta: Lifetime of the token, measured from the current UTC time.

    Returns:
        The encoded token produced by ``jwt.encode``.
    """
    auth_settings = get_settings_service().auth_settings
    claims = {**data, "exp": datetime.now(timezone.utc) + expires_delta}
    signing_key = auth_settings.SECRET_KEY.get_secret_value()
    return jwt.encode(claims, signing_key, algorithm=auth_settings.ALGORITHM)
async def create_super_user(
    username: str,
    password: str,
    db: AsyncSession,
) -> User:
    """Return the user named ``username``, creating it as an active superuser if absent.

    Args:
        username: Login name to look up or create.
        password: Plaintext password; hashed before storage. Ignored when the user exists.
        db: Database session used for the lookup and, if needed, the insert.

    Returns:
        The existing user, or the newly created and refreshed superuser.
    """
    existing = await get_user_by_username(db, username)
    if existing:
        # An existing account is returned untouched.
        return existing

    created = User(
        username=username,
        password=get_password_hash(password),
        is_superuser=True,
        is_active=True,
        last_login_at=None,
    )
    db.add(created)
    await db.commit()
    await db.refresh(created)
    return created
async def create_user_longterm_token(db: AsyncSession) -> tuple[UUID, dict]:
    """Issue a 365-day access token for the configured superuser.

    Args:
        db: Database session used to load the superuser and record the login time.

    Returns:
        A pair of the superuser's id and a token payload whose ``refresh_token`` is ``None``.

    Raises:
        HTTPException: 400 when the configured superuser does not exist.
    """
    superuser_name = get_settings_service().auth_settings.SUPERUSER
    super_user = await get_user_by_username(db, superuser_name)
    if not super_user:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Super user hasn't been created")

    token = create_token(
        data={"sub": str(super_user.id), "type": "access"},
        expires_delta=timedelta(days=365),
    )

    # Update: last_login_at
    await update_user_last_login_at(super_user.id, db)

    token_payload = {
        "access_token": token,
        "refresh_token": None,
        "token_type": "bearer",
    }
    return super_user.id, token_payload
def create_user_api_key(user_id: UUID) -> dict:
    """Build a JWT-based API key for ``user_id`` that is valid for two years (730 days).

    Returns:
        A dict with the single key ``api_key`` holding the encoded token.
    """
    claims = {"sub": str(user_id), "type": "api_key"}
    lifetime = timedelta(days=365 * 2)
    return {"api_key": create_token(data=claims, expires_delta=lifetime)}
def get_user_id_from_token(token: str) -> UUID:
    """Read the user id from a JWT's ``sub`` claim WITHOUT verifying the token.

    The signature and expiry are not checked, so the result must not be used on its
    own to make an authorization decision.

    Args:
        token: Encoded JWT.

    Returns:
        The UUID parsed from ``sub``, or the nil UUID (``UUID(int=0)``) when the token
        cannot be parsed, has no ``sub`` claim, or ``sub`` is not a valid UUID string.
    """
    try:
        subject = jwt.get_unverified_claims(token)["sub"]
        return UUID(subject)
    except (KeyError, JWTError, ValueError, TypeError, AttributeError):
        # TypeError/AttributeError cover a `sub` claim that is present but not a string
        # (e.g. null or a number), which UUID() rejects with those exception types.
        return UUID(int=0)
async def create_user_tokens(user_id: UUID, db: AsyncSession, *, update_last_login: bool = False) -> dict:
    """Issue an access/refresh token pair for ``user_id``.

    Args:
        user_id: Id placed in the ``sub`` claim of both tokens.
        db: Database session, used only when ``update_last_login`` is true.
        update_last_login: When true, record the user's last login time.

    Returns:
        A dict with ``access_token``, ``refresh_token`` and ``token_type`` ("bearer").
    """
    auth_settings = get_settings_service().auth_settings
    subject = str(user_id)

    def _issue(kind: str, lifetime_seconds):
        # Both tokens differ only in their "type" claim and lifetime.
        return create_token(
            data={"sub": subject, "type": kind},
            expires_delta=timedelta(seconds=lifetime_seconds),
        )

    access_token = _issue("access", auth_settings.ACCESS_TOKEN_EXPIRE_SECONDS)
    refresh_token = _issue("refresh", auth_settings.REFRESH_TOKEN_EXPIRE_SECONDS)

    # Update: last_login_at
    if update_last_login:
        await update_user_last_login_at(user_id, db)

    return {
        "access_token": access_token,
        "refresh_token": refresh_token,
        "token_type": "bearer",
    }
async def create_refresh_token(refresh_token: str, db: AsyncSession):
    """Validate a refresh token and issue a new access/refresh token pair.

    Args:
        refresh_token: Encoded JWT presented by the client.
        db: Database session used to confirm the user still exists.

    Returns:
        The token dict produced by ``create_user_tokens`` for the token's user.

    Raises:
        HTTPException: 401 when the token cannot be decoded, lacks a ``sub`` or ``type``
            claim, or names a user that does not exist.
    """
    settings_service = get_settings_service()

    try:
        # Ignore warning about datetime.utcnow
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            payload = jwt.decode(
                refresh_token,
                settings_service.auth_settings.SECRET_KEY.get_secret_value(),
                algorithms=[settings_service.auth_settings.ALGORITHM],
            )
        user_id: UUID = payload.get("sub")  # type: ignore[assignment]
        token_type: str = payload.get("type")  # type: ignore[assignment]

        # `payload.get` yields None for an absent claim, so test for falsiness rather
        # than only the empty string; otherwise a token without "type" slips through.
        # NOTE(review): the claim is not compared against "refresh", so any typed token
        # signed with this key is accepted here - confirm whether that is intended.
        if user_id is None or not token_type:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")

        user_exists = await get_user_by_id(db, user_id)

        if user_exists is None:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")

        return await create_user_tokens(user_id, db)

    except JWTError as e:
        logger.exception("JWT decoding error")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token",
        ) from e
async def authenticate_user(username: str, password: str, db: AsyncSession) -> User | None:
    """Look up ``username`` and verify ``password`` against the stored hash.

    Args:
        username: Login name to look up.
        password: Plaintext password to verify.
        db: Database session used for the lookup.

    Returns:
        The user when the password matches; ``None`` when the user does not exist or
        the password is wrong.

    Raises:
        HTTPException: 400 when the account is inactive and has never logged in;
            401 when the account is inactive otherwise.
    """
    account = await get_user_by_username(db, username)
    if not account:
        return None

    if account.is_active:
        if verify_password(password, account.password):
            return account
        return None

    # Inactive accounts are rejected before the password is examined.
    if not account.last_login_at:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Waiting for approval")
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user")
def add_padding(s):
    """Right-pad ``s`` with ``=`` so its length is a multiple of four (base64 padding).

    A string whose length is already a multiple of four is returned unchanged.

    Args:
        s: Base64 text that may be missing its trailing padding.

    Returns:
        ``s`` followed by zero to three ``=`` characters.
    """
    # -len % 4 is 0 when already aligned; `4 - len % 4` would wrongly append "====".
    padding_needed = -len(s) % 4
    return s + "=" * padding_needed
def ensure_valid_key(s: str) -> bytes:
    """Turn a secret string into bytes suitable for use as a Fernet key.

    Secrets shorter than ``MINIMUM_KEY_LENGTH`` are expanded deterministically into
    32 bytes and url-safe base64 encoded. Longer secrets are assumed to already be
    base64 text and are only padded and encoded to bytes.

    Args:
        s: The configured secret.

    Returns:
        The key as bytes.
    """
    # If the key is too short, we'll use it as a seed to generate a valid key
    if len(s) < MINIMUM_KEY_LENGTH:
        # Seed a private generator rather than the module-level one: `random.seed(s)`
        # would reset the process-wide RNG and make unrelated `random` calls predictable.
        # A Random instance seeded with the same string yields the same byte sequence,
        # so keys derived earlier remain valid.
        # NOTE(review): `random` is not a cryptographic generator; the derived key is
        # only as strong as the short secret itself.
        rng = random.Random(s)
        # Generate 32 random bytes
        raw_key = bytes(rng.getrandbits(8) for _ in range(32))
        return base64.urlsafe_b64encode(raw_key)
    return add_padding(s).encode()
def get_fernet(settings_service: SettingsService):
    """Build a ``Fernet`` cipher keyed from the configured ``SECRET_KEY``."""
    raw_secret: str = settings_service.auth_settings.SECRET_KEY.get_secret_value()
    return Fernet(ensure_valid_key(raw_secret))
def encrypt_api_key(api_key: str, settings_service: SettingsService):
    """Encrypt ``api_key`` with the application's Fernet cipher and return the token as text."""
    cipher = get_fernet(settings_service)
    # Two-way encryption
    token = cipher.encrypt(api_key.encode())
    return token.decode()
def decrypt_api_key(encrypted_api_key: str, settings_service: SettingsService):
    """Decrypt a Fernet token produced by ``encrypt_api_key``.

    Args:
        encrypted_api_key: The encrypted token as text.
        settings_service: Settings service supplying the secret key.

    Returns:
        The decrypted key, or an empty string when ``encrypted_api_key`` is not a ``str``.
        If the first decryption attempt fails, a second attempt is made with the
        un-encoded string, and any error from that attempt propagates.
    """
    cipher = get_fernet(settings_service)

    # Two-way decryption
    if not isinstance(encrypted_api_key, str):
        return ""

    try:
        return cipher.decrypt(encrypted_api_key.encode()).decode()
    except Exception:  # noqa: BLE001
        logger.debug("Failed to decrypt API key")
        return cipher.decrypt(encrypted_api_key).decode()