File size: 6,494 Bytes
e3278e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import os
from typing import Any, Optional, Union
import httpx
def init_rds_client(
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_region_name: Optional[str] = None,
    aws_session_name: Optional[str] = None,
    aws_profile_name: Optional[str] = None,
    aws_role_name: Optional[str] = None,
    aws_web_identity_token: Optional[str] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
):
    """Build a boto3 ``rds`` client from explicit params, env references, or defaults.

    Any string parameter of the form ``"os.environ/<NAME>"`` is first resolved
    through ``get_secret``.

    Credential selection, in priority order:
      1. ``aws_web_identity_token`` + ``aws_role_name`` + ``aws_session_name``:
         assume the role with STS ``AssumeRoleWithWebIdentity`` (1 hour). The
         token is read from a file if the value is a readable path, otherwise
         it is looked up with ``get_secret``.
      2. ``aws_role_name`` + ``aws_session_name``: assume the role with STS
         ``AssumeRole``, authenticating with the given key pair (may be None).
      3. ``aws_access_key_id``: use the given static key pair.
      4. ``aws_profile_name``: use the named boto3 profile.
      5. Otherwise: let boto3 resolve credentials on its own.

    Args:
        aws_access_key_id: Static access key id, or an ``os.environ/`` reference.
        aws_secret_access_key: Static secret key, or an ``os.environ/`` reference.
        aws_region_name: Region; falls back to the ``AWS_REGION_NAME`` and then
            ``AWS_REGION`` values returned by ``get_secret``.
        aws_session_name: Role session name used with STS.
        aws_profile_name: Named profile for ``boto3.Session``.
        aws_role_name: Role ARN to assume with STS.
        aws_web_identity_token: Path to an OIDC token file, or a secret name
            resolvable by ``get_secret``.
        timeout: Seconds (int or float) applied to both the connect and read
            timeouts, or an ``httpx.Timeout`` whose ``connect``/``read`` values
            are used. Anything else keeps the botocore defaults.

    Returns:
        A boto3 ``rds`` client.

    Raises:
        Exception: if no region can be determined, or if the OIDC token can be
            neither read from a file nor retrieved via ``get_secret``.
    """
    from litellm.secret_managers.main import get_secret

    # Region fallbacks, used when aws_region_name is not passed to init_rds_client.
    litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
    standard_aws_region_name = get_secret("AWS_REGION", None)

    def _resolve_env_reference(value):
        # Replace an "os.environ/<NAME>" reference with its secret value.
        if value and value.startswith("os.environ/"):
            return get_secret(value)  # type: ignore
        return value

    (
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ) = [
        _resolve_env_reference(param)
        for param in (
            aws_access_key_id,
            aws_secret_access_key,
            aws_region_name,
            aws_session_name,
            aws_profile_name,
            aws_role_name,
            aws_web_identity_token,
        )
    ]

    ### SET REGION NAME (first truthy candidate wins)
    region_name = (
        aws_region_name or litellm_aws_region_name or standard_aws_region_name
    )
    if not region_name:
        raise Exception(
            "AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
        )

    import boto3

    if isinstance(timeout, httpx.Timeout):
        config = boto3.session.Config(  # type: ignore
            connect_timeout=timeout.connect, read_timeout=timeout.read
        )
    elif isinstance(timeout, (int, float)) and not isinstance(timeout, bool):
        # Accept ints too: a plain `30` used to be silently ignored.
        config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)  # type: ignore
    else:
        config = boto3.session.Config()  # type: ignore

    def _rds_client_from_sts(sts_response):
        # Build the RDS client from the temporary credentials STS returned.
        credentials = sts_response["Credentials"]
        return boto3.client(
            service_name="rds",
            aws_access_key_id=credentials["AccessKeyId"],
            aws_secret_access_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
            region_name=region_name,
            config=config,
        )

    ### CHECK STS ###
    if (
        aws_web_identity_token is not None
        and aws_role_name is not None
        and aws_session_name is not None
    ):
        try:
            # The value may be a path to a token file...
            with open(aws_web_identity_token, encoding="utf-8") as token_file:
                oidc_token = token_file.read()
        except (OSError, ValueError):
            # ...otherwise treat it as a secret name.
            oidc_token = get_secret(aws_web_identity_token)
            if oidc_token is None:
                raise Exception(
                    "OIDC token could not be retrieved from secret manager.",
                )
        sts_client = boto3.client("sts")
        # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
        sts_response = sts_client.assume_role_with_web_identity(
            RoleArn=aws_role_name,
            RoleSessionName=aws_session_name,
            WebIdentityToken=oidc_token,
            DurationSeconds=3600,
        )
        return _rds_client_from_sts(sts_response)

    if aws_role_name is not None and aws_session_name is not None:
        # Use STS AssumeRole when a role name is passed in.
        sts_client = boto3.client(
            "sts",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )
        sts_response = sts_client.assume_role(
            RoleArn=aws_role_name, RoleSessionName=aws_session_name
        )
        return _rds_client_from_sts(sts_response)

    if aws_access_key_id is not None:
        # Static credentials passed explicitly by the caller.
        return boto3.client(
            service_name="rds",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name,
            config=config,
        )

    if aws_profile_name is not None:
        # Credentials from a named profile (usually in ~/.aws/credentials).
        return boto3.Session(profile_name=aws_profile_name).client(
            service_name="rds",
            region_name=region_name,
            config=config,
        )

    # No explicit credentials: boto3 resolves them itself (e.g. env variables).
    return boto3.client(
        service_name="rds",
        region_name=region_name,
        config=config,
    )
def generate_iam_auth_token(
    db_host, db_port, db_user, client: Optional[Any] = None
) -> str:
    """Return a percent-encoded IAM authentication token for an RDS user.

    Args:
        db_host: Hostname of the database instance.
        db_port: Port the database listens on.
        db_user: Database user name to authenticate as.
        client: An RDS client exposing ``generate_db_auth_token``. When None,
            one is built with ``init_rds_client`` from the AWS_* environment
            variables.

    Returns:
        The token passed through ``urllib.parse.quote`` with ``safe=""``, so
        every reserved character (including ``/``) is percent-encoded,
        presumably so it can be embedded in a connection URL.
    """
    from urllib.parse import quote

    rds = client
    if rds is None:
        env = os.getenv
        # AWS_ROLE_NAME takes precedence over AWS_ROLE_ARN, and
        # AWS_WEB_IDENTITY_TOKEN over AWS_WEB_IDENTITY_TOKEN_FILE.
        role = env("AWS_ROLE_NAME", env("AWS_ROLE_ARN"))
        web_identity = env(
            "AWS_WEB_IDENTITY_TOKEN", env("AWS_WEB_IDENTITY_TOKEN_FILE")
        )
        rds = init_rds_client(
            aws_region_name=env("AWS_REGION_NAME"),
            aws_access_key_id=env("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=env("AWS_SECRET_ACCESS_KEY"),
            aws_session_name=env("AWS_SESSION_NAME"),
            aws_profile_name=env("AWS_PROFILE_NAME"),
            aws_role_name=role,
            aws_web_identity_token=web_identity,
        )

    raw_token = rds.generate_db_auth_token(
        DBHostname=db_host, Port=db_port, DBUsername=db_user
    )
    return quote(raw_token, safe="")
|