Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	| from __future__ import annotations | |
| import os | |
| from datetime import datetime, timezone | |
| from typing import TYPE_CHECKING | |
| from loguru import logger | |
| from sqlmodel import Session, select | |
| from langflow.services.auth import utils as auth_utils | |
| from langflow.services.base import Service | |
| from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableUpdate | |
| from langflow.services.variable.base import VariableService | |
| from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE | |
| if TYPE_CHECKING: | |
| from collections.abc import Sequence | |
| from uuid import UUID | |
| from sqlmodel.ext.asyncio.session import AsyncSession | |
| from langflow.services.settings.service import SettingsService | |
| class DatabaseVariableService(VariableService, Service): | |
| def __init__(self, settings_service: SettingsService): | |
| self.settings_service = settings_service | |
| async def initialize_user_variables(self, user_id: UUID | str, session: AsyncSession) -> None: | |
| if not self.settings_service.settings.store_environment_variables: | |
| logger.info("Skipping environment variable storage.") | |
| return | |
| logger.info("Storing environment variables in the database.") | |
| for var_name in self.settings_service.settings.variables_to_get_from_environment: | |
| if var_name in os.environ and os.environ[var_name].strip(): | |
| value = os.environ[var_name].strip() | |
| query = select(Variable).where(Variable.user_id == user_id, Variable.name == var_name) | |
| existing = (await session.exec(query)).first() | |
| try: | |
| if existing: | |
| await self.update_variable(user_id, var_name, value, session) | |
| else: | |
| await self.create_variable( | |
| user_id=user_id, | |
| name=var_name, | |
| value=value, | |
| default_fields=[], | |
| type_=CREDENTIAL_TYPE, | |
| session=session, | |
| ) | |
| logger.info(f"Processed {var_name} variable from environment.") | |
| except Exception as e: # noqa: BLE001 | |
| logger.exception(f"Error processing {var_name} variable: {e!s}") | |
| def get_variable( | |
| self, | |
| user_id: UUID | str, | |
| name: str, | |
| field: str, | |
| session: Session, | |
| ) -> str: | |
| # we get the credential from the database | |
| # credential = session.query(Variable).filter(Variable.user_id == user_id, Variable.name == name).first() | |
| variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first() | |
| if not variable or not variable.value: | |
| msg = f"{name} variable not found." | |
| raise ValueError(msg) | |
| if variable.type == CREDENTIAL_TYPE and field == "session_id": | |
| msg = ( | |
| f"variable {name} of type 'Credential' cannot be used in a Session ID field " | |
| "because its purpose is to prevent the exposure of values." | |
| ) | |
| raise TypeError(msg) | |
| # we decrypt the value | |
| return auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service) | |
| async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[Variable | None]: | |
| stmt = select(Variable).where(Variable.user_id == user_id) | |
| return list((await session.exec(stmt)).all()) | |
| def list_variables_sync(self, user_id: UUID | str, session: Session) -> list[str | None]: | |
| variables = session.exec(select(Variable).where(Variable.user_id == user_id)).all() | |
| return [variable.name for variable in variables if variable] | |
| async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]: | |
| variables = await self.get_all(user_id=user_id, session=session) | |
| return [variable.name for variable in variables if variable] | |
| async def update_variable( | |
| self, | |
| user_id: UUID | str, | |
| name: str, | |
| value: str, | |
| session: AsyncSession, | |
| ): | |
| stmt = select(Variable).where(Variable.user_id == user_id, Variable.name == name) | |
| variable = (await session.exec(stmt)).first() | |
| if not variable: | |
| msg = f"{name} variable not found." | |
| raise ValueError(msg) | |
| encrypted = auth_utils.encrypt_api_key(value, settings_service=self.settings_service) | |
| variable.value = encrypted | |
| session.add(variable) | |
| await session.commit() | |
| await session.refresh(variable) | |
| return variable | |
| async def update_variable_fields( | |
| self, | |
| user_id: UUID | str, | |
| variable_id: UUID | str, | |
| variable: VariableUpdate, | |
| session: AsyncSession, | |
| ): | |
| query = select(Variable).where(Variable.id == variable_id, Variable.user_id == user_id) | |
| db_variable = (await session.exec(query)).one() | |
| db_variable.updated_at = datetime.now(timezone.utc) | |
| variable.value = variable.value or "" | |
| encrypted = auth_utils.encrypt_api_key(variable.value, settings_service=self.settings_service) | |
| variable.value = encrypted | |
| variable_data = variable.model_dump(exclude_unset=True) | |
| for key, value in variable_data.items(): | |
| setattr(db_variable, key, value) | |
| session.add(db_variable) | |
| await session.commit() | |
| await session.refresh(db_variable) | |
| return db_variable | |
| async def delete_variable( | |
| self, | |
| user_id: UUID | str, | |
| name: str, | |
| session: AsyncSession, | |
| ) -> None: | |
| stmt = select(Variable).where(Variable.user_id == user_id).where(Variable.name == name) | |
| variable = (await session.exec(stmt)).first() | |
| if not variable: | |
| msg = f"{name} variable not found." | |
| raise ValueError(msg) | |
| await session.delete(variable) | |
| await session.commit() | |
| async def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: AsyncSession) -> None: | |
| stmt = select(Variable).where(Variable.user_id == user_id, Variable.id == variable_id) | |
| variable = (await session.exec(stmt)).first() | |
| if not variable: | |
| msg = f"{variable_id} variable not found." | |
| raise ValueError(msg) | |
| await session.delete(variable) | |
| await session.commit() | |
| async def create_variable( | |
| self, | |
| user_id: UUID | str, | |
| name: str, | |
| value: str, | |
| *, | |
| default_fields: Sequence[str] = (), | |
| type_: str = GENERIC_TYPE, | |
| session: AsyncSession, | |
| ): | |
| variable_base = VariableCreate( | |
| name=name, | |
| type=type_, | |
| value=auth_utils.encrypt_api_key(value, settings_service=self.settings_service), | |
| default_fields=list(default_fields), | |
| ) | |
| variable = Variable.model_validate(variable_base, from_attributes=True, update={"user_id": user_id}) | |
| session.add(variable) | |
| await session.commit() | |
| await session.refresh(variable) | |
| return variable | |
