|
""" |
|
This file contains the PrismaWrapper class, which is used to wrap the Prisma client and handle the RDS IAM token. |
|
""" |
|
|
|
import asyncio
import os
import urllib
import urllib.parse
from datetime import datetime, timedelta, timezone
from typing import Any, Optional

from litellm.secret_managers.main import str_to_bool
|
|
|
|
|
class PrismaWrapper:
    """Proxy around a Prisma client that can refresh RDS IAM auth tokens.

    Attribute access is forwarded to the wrapped client through
    ``__getattr__``. When ``iam_token_db_auth`` is enabled, every forwarded
    access first checks whether the IAM token embedded in the
    ``DATABASE_URL`` environment variable has expired. If it has, a new token
    is generated and the underlying Prisma client is recreated.
    """

    def __init__(self, original_prisma: Any, iam_token_db_auth: bool):
        """
        Args:
            original_prisma: the Prisma client instance to wrap.
            iam_token_db_auth: when True, check token expiry and recreate the
                client on attribute access.
        """
        self._original_prisma = original_prisma
        self.iam_token_db_auth = iam_token_db_auth

    def is_token_expired(self, token_url: Optional[str]) -> bool:
        """Return True if the IAM token embedded in ``token_url`` has expired.

        The URL is percent-decoded. Its query string is then searched for
        ``X-Amz-Date`` (format ``YYYYMMDDTHHMMSSZ``, interpreted as UTC) and
        ``X-Amz-Expires`` (lifetime in seconds).

        Args:
            token_url: database URL carrying the token. ``None`` is treated
                as expired.

        Raises:
            ValueError: if either parameter is missing or malformed.
        """
        if token_url is None:
            return True

        # Decode first: the token's own query string is expected to be
        # percent-encoded inside the database URL. Its parameters only become
        # visible to parse_qs after unquoting.
        decoded_url = urllib.parse.unquote(token_url)
        query_params = urllib.parse.parse_qs(urllib.parse.urlparse(decoded_url).query)

        expires = query_params.get("X-Amz-Expires", [None])[0]
        if expires is None:
            raise ValueError("X-Amz-Expires parameter is missing or invalid.")
        try:
            expires_int = int(expires)
        except ValueError as e:
            raise ValueError(f"Invalid X-Amz-Expires value: {expires!r}") from e

        token_time_str = query_params.get("X-Amz-Date", [""])[0]
        if not token_time_str:
            raise ValueError("X-Amz-Date parameter is missing or invalid.")
        try:
            token_time = datetime.strptime(token_time_str, "%Y%m%dT%H%M%SZ").replace(
                tzinfo=timezone.utc
            )
        except ValueError as e:
            raise ValueError(f"Invalid X-Amz-Date format: {e}") from e

        expiration_time = token_time + timedelta(seconds=expires_int)
        # Aware UTC comparison; datetime.utcnow() is deprecated since 3.12.
        return datetime.now(timezone.utc) > expiration_time

    def get_rds_iam_token(self) -> Optional[str]:
        """Build a PostgreSQL URL whose password is a fresh RDS IAM token.

        Reads DATABASE_HOST, DATABASE_PORT, DATABASE_USER and DATABASE_NAME,
        plus the optional DATABASE_SCHEMA, from the environment. It then
        generates a token via ``generate_iam_auth_token`` and stores the
        resulting URL in ``os.environ["DATABASE_URL"]``.

        Returns:
            The new URL, or ``None`` when IAM token auth is disabled.
        """
        if not self.iam_token_db_auth:
            return None

        from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token

        # NOTE: these variables are not validated. An unset one is
        # interpolated into the URL as the literal text "None".
        db_host = os.getenv("DATABASE_HOST")
        db_port = os.getenv("DATABASE_PORT")
        db_user = os.getenv("DATABASE_USER")
        db_name = os.getenv("DATABASE_NAME")
        db_schema = os.getenv("DATABASE_SCHEMA")

        token = generate_iam_auth_token(
            db_host=db_host, db_port=db_port, db_user=db_user
        )

        _db_url = f"postgresql://{db_user}:{token}@{db_host}:{db_port}/{db_name}"
        if db_schema:
            _db_url += f"?schema={db_schema}"

        # Publish the URL so later expiry checks (and, presumably, Prisma()
        # itself) see the refreshed token.
        os.environ["DATABASE_URL"] = _db_url
        return _db_url

    async def recreate_prisma_client(
        self, new_db_url: str, http_client: Optional[Any] = None
    ):
        """Create and connect a new Prisma client, then swap it in.

        Args:
            new_db_url: the refreshed database URL. NOTE(review): it is not
                passed to ``Prisma()``. The new client presumably reads
                ``DATABASE_URL`` from the environment (set by
                ``get_rds_iam_token``); confirm.
            http_client: optional HTTP client forwarded as ``Prisma(http=...)``.
        """
        from prisma import Prisma

        new_client = Prisma(http=http_client) if http_client is not None else Prisma()
        # Connect before publishing. Assigning first would let concurrent
        # attribute access reach a client that is not yet connected. If
        # connect() raises, the previous client stays in place. The previous
        # client is not disconnected here.
        await new_client.connect()
        self._original_prisma = new_client

    def __getattr__(self, name: str):
        """Forward attribute access to the wrapped Prisma client.

        With IAM auth enabled, an expired token in ``DATABASE_URL`` triggers
        token regeneration and client recreation before the attribute is
        resolved.

        Raises:
            ValueError: if a new RDS IAM token URL could not be produced.
        """
        # __getattr__ only runs for missing attributes. If the wrapper's own
        # state is missing (e.g. an instance created via __new__), fail fast
        # instead of recursing forever.
        if name in ("_original_prisma", "iam_token_db_auth"):
            raise AttributeError(name)

        if self.iam_token_db_auth:
            db_url = os.getenv("DATABASE_URL")
            if self.is_token_expired(db_url):
                db_url = self.get_rds_iam_token()
                if not db_url:
                    raise ValueError("Failed to get RDS IAM token")

                try:
                    loop: Optional[asyncio.AbstractEventLoop] = asyncio.get_running_loop()
                except RuntimeError:
                    loop = None

                if loop is not None:
                    # Inside a running loop this synchronous hook cannot wait
                    # for the reconnect, so it is only scheduled. The current
                    # client keeps serving until the swap. Any exception from
                    # the scheduled task is not observed here.
                    asyncio.run_coroutine_threadsafe(
                        self.recreate_prisma_client(db_url), loop
                    )
                else:
                    asyncio.run(self.recreate_prisma_client(db_url))

        # Resolve after any synchronous reconnect so the caller receives the
        # new client's attribute, not the stale client's.
        return getattr(self._original_prisma, name)
|
|
|
|
|
def should_update_schema(disable_prisma_schema_update: Optional[bool]):
    """
    Decide whether the Prisma schema update should run.

    Args:
        disable_prisma_schema_update: explicit override. When ``None``, the
            value is taken from the ``DISABLE_SCHEMA_UPDATE`` environment
            variable, parsed with ``str_to_bool``.

    Returns:
        False only when the resolved flag is exactly ``True``; True otherwise.
    """
    disabled = disable_prisma_schema_update
    if disabled is None:
        disabled = str_to_bool(os.getenv("DISABLE_SCHEMA_UPDATE"))
    # Identity check on purpose: only a literal True disables the update.
    return disabled is not True
|
|