Tai Truong
fix readme
d202ada
from __future__ import annotations
import os
from datetime import datetime, timezone
from typing import TYPE_CHECKING
from loguru import logger
from sqlmodel import Session, select
from langflow.services.auth import utils as auth_utils
from langflow.services.base import Service
from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableUpdate
from langflow.services.variable.base import VariableService
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
if TYPE_CHECKING:
from collections.abc import Sequence
from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession
from langflow.services.settings.service import SettingsService
class DatabaseVariableService(VariableService, Service):
    """Variable service that stores per-user variables in the database.

    Values are encrypted with ``auth_utils.encrypt_api_key`` before they are written,
    and :meth:`get_variable` decrypts them with ``auth_utils.decrypt_api_key``.
    """

    def __init__(self, settings_service: SettingsService):
        self.settings_service = settings_service

    async def initialize_user_variables(self, user_id: UUID | str, session: AsyncSession) -> None:
        """Copy selected environment variables into the user's stored variables.

        Does nothing (beyond logging) unless ``settings.store_environment_variables`` is set.
        For each name in ``settings.variables_to_get_from_environment`` whose environment
        value is non-blank, the stripped value is stored as a credential-type variable.
        An existing variable with that name is updated; otherwise a new one is created.
        Errors are logged per variable and do not abort the loop.

        Args:
            user_id: Owner of the variables.
            session: Async database session; each stored variable is committed on it.
        """
        settings = self.settings_service.settings
        if not settings.store_environment_variables:
            logger.info("Skipping environment variable storage.")
            return
        logger.info("Storing environment variables in the database.")
        for var_name in settings.variables_to_get_from_environment:
            value = os.environ.get(var_name, "").strip()
            if not value:
                continue
            try:
                # The lookup is inside the try so that a session broken by an earlier
                # failure cannot escape this best-effort loop.
                query = select(Variable).where(Variable.user_id == user_id, Variable.name == var_name)
                existing = (await session.exec(query)).first()
                if existing:
                    await self.update_variable(user_id, var_name, value, session)
                else:
                    await self.create_variable(
                        user_id=user_id,
                        name=var_name,
                        value=value,
                        default_fields=[],
                        type_=CREDENTIAL_TYPE,
                        session=session,
                    )
                logger.info(f"Processed {var_name} variable from environment.")
            except Exception as e:  # noqa: BLE001
                logger.exception(f"Error processing {var_name} variable: {e!s}")
                # After a failed flush/commit the session refuses further work until it is
                # rolled back; roll back so later variables (and the caller) can still use it.
                await session.rollback()

    def get_variable(
        self,
        user_id: UUID | str,
        name: str,
        field: str,
        session: Session,
    ) -> str:
        """Return the decrypted value of the user's variable ``name``.

        Args:
            user_id: Owner of the variable.
            name: Variable name.
            field: Field the value is requested for; credential-type variables are refused
                for the ``"session_id"`` field.
            session: Synchronous database session.

        Returns:
            The decrypted variable value.

        Raises:
            ValueError: If the variable does not exist or its stored value is empty.
            TypeError: If a credential-type variable is requested for ``"session_id"``.
        """
        stmt = select(Variable).where(Variable.user_id == user_id, Variable.name == name)
        variable = session.exec(stmt).first()
        if not variable or not variable.value:
            msg = f"{name} variable not found."
            raise ValueError(msg)
        if variable.type == CREDENTIAL_TYPE and field == "session_id":
            msg = (
                f"variable {name} of type 'Credential' cannot be used in a Session ID field "
                "because its purpose is to prevent the exposure of values."
            )
            raise TypeError(msg)
        return auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service)

    async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[Variable | None]:
        """Return every variable row owned by ``user_id`` (values stay encrypted)."""
        stmt = select(Variable).where(Variable.user_id == user_id)
        return list((await session.exec(stmt)).all())

    def list_variables_sync(self, user_id: UUID | str, session: Session) -> list[str | None]:
        """Return the names of the user's variables using a synchronous session."""
        variables = session.exec(select(Variable).where(Variable.user_id == user_id)).all()
        return [variable.name for variable in variables if variable]

    async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]:
        """Return the names of the user's variables."""
        variables = await self.get_all(user_id=user_id, session=session)
        return [variable.name for variable in variables if variable]

    async def update_variable(
        self,
        user_id: UUID | str,
        name: str,
        value: str,
        session: AsyncSession,
    ) -> Variable:
        """Encrypt ``value`` and store it as the new value of the user's variable ``name``.

        Commits the change and returns the refreshed row.

        Raises:
            ValueError: If the user has no variable with that name.
        """
        stmt = select(Variable).where(Variable.user_id == user_id, Variable.name == name)
        variable = (await session.exec(stmt)).first()
        if not variable:
            msg = f"{name} variable not found."
            raise ValueError(msg)
        variable.value = auth_utils.encrypt_api_key(value, settings_service=self.settings_service)
        session.add(variable)
        await session.commit()
        await session.refresh(variable)
        return variable

    async def update_variable_fields(
        self,
        user_id: UUID | str,
        variable_id: UUID | str,
        variable: VariableUpdate,
        session: AsyncSession,
    ) -> Variable:
        """Apply the fields explicitly set on ``variable`` to the stored variable ``variable_id``.

        Only fields present in ``variable.model_dump(exclude_unset=True)`` are copied, and
        ``updated_at`` is refreshed. A provided ``value`` is encrypted before storage, with
        ``None`` stored as an encrypted empty string. When ``value`` is not provided, the
        stored value is left unchanged rather than being overwritten. The caller's
        ``VariableUpdate`` object is not modified.

        Raises:
            sqlalchemy.exc.NoResultFound: Raised by ``.one()`` when no variable with
                ``variable_id`` belongs to ``user_id``.
        """
        query = select(Variable).where(Variable.id == variable_id, Variable.user_id == user_id)
        db_variable = (await session.exec(query)).one()
        db_variable.updated_at = datetime.now(timezone.utc)

        variable_data = variable.model_dump(exclude_unset=True)
        if "value" in variable_data:
            variable_data["value"] = auth_utils.encrypt_api_key(
                variable_data["value"] or "", settings_service=self.settings_service
            )
        for key, value in variable_data.items():
            setattr(db_variable, key, value)

        session.add(db_variable)
        await session.commit()
        await session.refresh(db_variable)
        return db_variable

    async def delete_variable(
        self,
        user_id: UUID | str,
        name: str,
        session: AsyncSession,
    ) -> None:
        """Delete the user's variable ``name`` and commit.

        Raises:
            ValueError: If the user has no variable with that name.
        """
        stmt = select(Variable).where(Variable.user_id == user_id, Variable.name == name)
        variable = (await session.exec(stmt)).first()
        if not variable:
            msg = f"{name} variable not found."
            raise ValueError(msg)
        await session.delete(variable)
        await session.commit()

    async def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: AsyncSession) -> None:
        """Delete the user's variable with id ``variable_id`` and commit.

        Raises:
            ValueError: If no variable with that id belongs to the user.
        """
        stmt = select(Variable).where(Variable.user_id == user_id, Variable.id == variable_id)
        variable = (await session.exec(stmt)).first()
        if not variable:
            msg = f"{variable_id} variable not found."
            raise ValueError(msg)
        await session.delete(variable)
        await session.commit()

    async def create_variable(
        self,
        user_id: UUID | str,
        name: str,
        value: str,
        *,
        default_fields: Sequence[str] = (),
        type_: str = GENERIC_TYPE,
        session: AsyncSession,
    ) -> Variable:
        """Create, commit and return a new variable owned by ``user_id``.

        Args:
            user_id: Owner of the new variable.
            name: Variable name.
            value: Plain-text value; it is encrypted before being stored.
            default_fields: Stored as a list on the new variable.
            type_: Variable type, ``GENERIC_TYPE`` by default.
            session: Async database session used to insert and commit the row.
        """
        variable_base = VariableCreate(
            name=name,
            type=type_,
            value=auth_utils.encrypt_api_key(value, settings_service=self.settings_service),
            default_fields=list(default_fields),
        )
        variable = Variable.model_validate(variable_base, from_attributes=True, update={"user_id": user_id})
        session.add(variable)
        await session.commit()
        await session.refresh(variable)
        return variable