File size: 9,334 Bytes
e3278e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 |
import traceback
from typing import Optional
import httpx
from fastapi import APIRouter, HTTPException, Request, Response, status
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
create_pass_through_route,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.passthrough_endpoints.vertex_ai import *
from .vertex_passthrough_router import VertexPassThroughRouter
# FastAPI router that the Vertex AI pass-through routes in this module attach to.
router: APIRouter = APIRouter()

# Resolves Vertex credentials for a (project_id, location) pair; used by
# vertex_proxy_route via get_vertex_credentials().
vertex_pass_through_router: VertexPassThroughRouter = VertexPassThroughRouter()

# Module-wide default Vertex config. Replaced by set_default_vertex_config()
# and _set_default_vertex_config().
default_vertex_config: VertexPassThroughCredentials = (
    VertexPassThroughCredentials()
)
def _get_vertex_env_vars() -> VertexPassThroughCredentials:
    """
    Build Vertex pass-through credentials from environment / secret settings.

    Each value is read with ``get_secret_str``:
      - DEFAULT_VERTEXAI_PROJECT -> ``vertex_project`` (project id)
      - DEFAULT_VERTEXAI_LOCATION -> ``vertex_location`` (location)
      - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS -> ``vertex_credentials``
        (path to credentials file)

    Returns:
        VertexPassThroughCredentials: whatever ``get_secret_str`` returns for
        each name (presumably None when unset; verify against get_secret_str).
    """
    project_id = get_secret_str("DEFAULT_VERTEXAI_PROJECT")
    location = get_secret_str("DEFAULT_VERTEXAI_LOCATION")
    credentials = get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS")

    return VertexPassThroughCredentials(
        vertex_project=project_id,
        vertex_location=location,
        vertex_credentials=credentials,
    )
def set_default_vertex_config(config: Optional[dict] = None):
    """Sets vertex configuration from provided config and/or environment variables

    When ``config`` is None, the defaults are loaded from the
    DEFAULT_VERTEXAI_* environment variables (see ``_get_vertex_env_vars``).
    Otherwise, every string value of the form ``"os.environ/<NAME>"`` is
    resolved with ``litellm.get_secret``. The resolved mapping is then used to
    build the new ``VertexPassThroughCredentials`` default.

    The caller's dict is NOT modified. Resolved secrets are written to a copy,
    so they do not leak back into the caller's (e.g. proxy-wide) settings.

    Args:
        config (Optional[dict]): Configuration dictionary
            Example: {
                "vertex_project": "my-project-123",
                "vertex_location": "us-central1",
                "vertex_credentials": "os.environ/GOOGLE_CREDS"
            }
    """
    global default_vertex_config

    # No explicit config: fall back to environment variables.
    if config is None:
        default_vertex_config = _get_vertex_env_vars()
        return

    resolved_config = config
    if isinstance(config, dict):
        # Resolve "os.environ/..." references into a fresh dict instead of
        # overwriting the caller's values in place.
        resolved_config = {
            key: (
                litellm.get_secret(value)
                if isinstance(value, str) and value.startswith("os.environ/")
                else value
            )
            for key, value in config.items()
        }

    _set_default_vertex_config(VertexPassThroughCredentials(**resolved_config))
def _set_default_vertex_config(
    vertex_pass_through_credentials: VertexPassThroughCredentials,
):
    """
    Install the given credentials as the module-level default.

    Args:
        vertex_pass_through_credentials: becomes the new value of the
            module-level ``default_vertex_config``.
    """
    global default_vertex_config
    default_vertex_config = vertex_pass_through_credentials
def exception_handler(e: Exception):
    """
    Log ``e`` and convert it into a ``ProxyException``.

    The error message is logged at error level and the current traceback at
    debug level.

    ``message``, ``type``, ``param`` and ``code`` are taken from the matching
    attribute on ``e`` when it exists (``code`` from ``status_code``).
    Fallbacks:
      - HTTPException: message = str(e.detail), code = 400
      - any other exception: message = str(e), code = 500
      - ``type`` / ``param``: the literal string "None"
    """
    verbose_proxy_logger.error(
        "litellm.proxy.proxy_server.v1/projects/tuningJobs(): Exception occurred - {}".format(
            str(e)
        )
    )
    verbose_proxy_logger.debug(traceback.format_exc())

    # Only the fallbacks differ between the two cases; pick them up front.
    if isinstance(e, HTTPException):
        fallback_message = str(e.detail)
        fallback_code = status.HTTP_400_BAD_REQUEST
    else:
        fallback_message = f"{str(e)}"
        fallback_code = 500

    return ProxyException(
        message=getattr(e, "message", fallback_message),
        type=getattr(e, "type", "None"),
        param=getattr(e, "param", "None"),
        code=getattr(e, "status_code", fallback_code),
    )
def construct_target_url(
    base_url: str,
    requested_route: str,
    default_vertex_location: Optional[str],
    default_vertex_project: Optional[str],
) -> httpx.URL:
    """
    Build the upstream Vertex AI URL for a pass-through request.

    Allow user to specify their own project id / location; if missing, use
    the defaults.

    - If ``requested_route`` contains "locations", it is assumed to already
      carry the target project id + location. It replaces the path of
      ``base_url`` unchanged.
    - Otherwise the route is prefixed with
      ``/{version}/projects/{default_vertex_project}/locations/{default_vertex_location}``.
      Here ``version`` is "v1beta1" for cachedContent routes and "v1" for the
      rest. See the cachedContent scenario:
      https://github.com/BerriAI/litellm/issues/5460

    Constructed Url:
        POST https://LOCATION-aiplatform.googleapis.com/{version}/projects/PROJECT_ID/locations/LOCATION/cachedContents

    Note: None defaults are rendered literally (e.g. ".../projects/None/...").
    """
    base = httpx.URL(base_url)

    # Route already names its project + location: use it as-is.
    if "locations" in requested_route:
        return base.copy_with(path=requested_route)

    # cachedContent is served by the v1beta1 API; everything else by v1.
    vertex_version: Literal["v1", "v1beta1"] = (
        "v1beta1" if "cachedContent" in requested_route else "v1"
    )
    prefix = (
        f"/{vertex_version}/projects/{default_vertex_project}"
        f"/locations/{default_vertex_location}"
    )
    return base.copy_with(path=prefix + requested_route)
# Header names whose values must never appear in logs.
_SENSITIVE_HEADER_NAMES = frozenset(
    {"authorization", "x-litellm-api-key", "x-goog-api-key", "cookie"}
)


def _redact_headers_for_log(headers: dict) -> dict:
    """Return a copy of ``headers`` with credential-bearing values masked."""
    return {
        name: "***REDACTED***" if name.lower() in _SENSITIVE_HEADER_NAMES else value
        for name, value in headers.items()
    }


def _is_streaming_vertex_path(path: str) -> bool:
    """
    Return True when the Vertex REST method in ``path`` is a streaming one.

    Only the ``:method`` suffix of the final path segment is inspected (e.g.
    ``gemini-pro:streamGenerateContent`` -> ``streamGenerateContent``). If the
    final segment has no ':', the whole segment is inspected. This way a
    project or model id that merely contains "stream" is not misdetected.
    """
    last_segment = path.rsplit("/", 1)[-1]
    method = last_segment.rpartition(":")[2]
    return "stream" in method


@router.api_route(
    "/vertex-ai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Vertex AI Pass-through", "pass-through"],
    include_in_schema=False,
)
@router.api_route(
    "/vertex_ai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Vertex AI Pass-through", "pass-through"],
)
async def vertex_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
):
    """
    Call LiteLLM proxy via Vertex AI SDK.

    [Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
    """
    encoded_endpoint = httpx.URL(endpoint).path
    verbose_proxy_logger.debug("requested endpoint %s", endpoint)

    # Authenticate the caller with their LiteLLM virtual key.
    api_key_to_use = get_litellm_virtual_key(request=request)
    user_api_key_dict = await user_api_key_auth(
        request=request,
        api_key=api_key_to_use,
    )

    vertex_project: Optional[str] = (
        VertexPassThroughRouter._get_vertex_project_id_from_url(endpoint)
    )
    vertex_location: Optional[str] = (
        VertexPassThroughRouter._get_vertex_location_from_url(endpoint)
    )
    vertex_credentials = vertex_pass_through_router.get_vertex_credentials(
        project_id=vertex_project,
        location=vertex_location,
    )

    headers: dict
    if vertex_credentials.vertex_project is None:
        # No Vertex credentials configured: forward the incoming request headers.
        headers = dict(request.headers)
        headers.pop("content-length", None)
        headers.pop("host", None)
        # The LiteLLM virtual key is only meant for this proxy; never send it
        # upstream to Google.
        headers.pop("x-litellm-api-key", None)
        verbose_proxy_logger.debug(
            "default_vertex_config not set, incoming request headers %s",
            _redact_headers_for_log(headers),
        )
        base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
    else:
        vertex_project = vertex_credentials.vertex_project
        vertex_location = vertex_credentials.vertex_location
        vertex_credentials_str = vertex_credentials.vertex_credentials

        # Construct base URL for the target endpoint
        base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"

        _auth_header, vertex_project = (
            await vertex_fine_tuning_apis_instance._ensure_access_token_async(
                credentials=vertex_credentials_str,
                project_id=vertex_project,
                custom_llm_provider="vertex_ai_beta",
            )
        )

        auth_header, _ = vertex_fine_tuning_apis_instance._get_token_and_url(
            model="",
            auth_header=_auth_header,
            gemini_api_key=None,
            vertex_credentials=vertex_credentials_str,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            stream=False,
            custom_llm_provider="vertex_ai_beta",
            api_base="",
        )

        # Only the proxy's own token is sent upstream; caller headers are dropped.
        headers = {
            "Authorization": f"Bearer {auth_header}",
        }

    verbose_proxy_logger.debug("request_route %s", encoded_endpoint)

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    updated_url = construct_target_url(
        base_url=base_target_url,
        requested_route=encoded_endpoint,
        default_vertex_location=vertex_location,
        default_vertex_project=vertex_project,
    )
    verbose_proxy_logger.debug("updated url %s", updated_url)

    ## check for streaming
    target = str(updated_url)
    is_streaming_request = _is_streaming_vertex_path(updated_url.path)
    if is_streaming_request:
        # Vertex streams as server-sent events when alt=sse is requested.
        target += "?alt=sse"

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=target,
        custom_headers=headers,
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value
def get_litellm_virtual_key(request: Request) -> str:
    """
    Return the caller's LiteLLM virtual key formatted as an auth header value.

    Vertex JS SDK uses `Authorization` header, we use `x-litellm-api-key` to
    pass litellm virtual key.

    Precedence:
      1. a non-empty ``x-litellm-api-key`` header -> "Bearer <key>"
      2. the ``Authorization`` header, returned verbatim
      3. "" when neither is available
    """
    explicit_key = request.headers.get("x-litellm-api-key")
    return (
        f"Bearer {explicit_key}"
        if explicit_key
        else request.headers.get("Authorization", "")
    )
|