Raju2024's picture
Upload 1072 files
e3278e4 verified
raw
history blame
9.33 kB
import traceback
from typing import Optional
import httpx
from fastapi import APIRouter, HTTPException, Request, Response, status
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
create_pass_through_route,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.passthrough_endpoints.vertex_ai import *
from .vertex_passthrough_router import VertexPassThroughRouter
router = APIRouter()
vertex_pass_through_router = VertexPassThroughRouter()
default_vertex_config: VertexPassThroughCredentials = VertexPassThroughCredentials()
def _get_vertex_env_vars() -> VertexPassThroughCredentials:
"""
Helper to get vertex pass through config from environment variables
The following environment variables are used:
- DEFAULT_VERTEXAI_PROJECT (project id)
- DEFAULT_VERTEXAI_LOCATION (location)
- DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file)
"""
return VertexPassThroughCredentials(
vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"),
vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"),
vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"),
)
def set_default_vertex_config(config: Optional[dict] = None):
"""Sets vertex configuration from provided config and/or environment variables
Args:
config (Optional[dict]): Configuration dictionary
Example: {
"vertex_project": "my-project-123",
"vertex_location": "us-central1",
"vertex_credentials": "os.environ/GOOGLE_CREDS"
}
"""
global default_vertex_config
# Initialize config dictionary if None
if config is None:
default_vertex_config = _get_vertex_env_vars()
return
if isinstance(config, dict):
for key, value in config.items():
if isinstance(value, str) and value.startswith("os.environ/"):
config[key] = litellm.get_secret(value)
_set_default_vertex_config(VertexPassThroughCredentials(**config))
def _set_default_vertex_config(
vertex_pass_through_credentials: VertexPassThroughCredentials,
):
global default_vertex_config
default_vertex_config = vertex_pass_through_credentials
def exception_handler(e: Exception):
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.v1/projects/tuningJobs(): Exception occurred - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
return ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_msg = f"{str(e)}"
return ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
def construct_target_url(
base_url: str,
requested_route: str,
default_vertex_location: Optional[str],
default_vertex_project: Optional[str],
) -> httpx.URL:
"""
Allow user to specify their own project id / location.
If missing, use defaults
Handle cachedContent scenario - https://github.com/BerriAI/litellm/issues/5460
Constructed Url:
POST https://LOCATION-aiplatform.googleapis.com/{version}/projects/PROJECT_ID/locations/LOCATION/cachedContents
"""
new_base_url = httpx.URL(base_url)
if "locations" in requested_route: # contains the target project id + location
updated_url = new_base_url.copy_with(path=requested_route)
return updated_url
"""
- Add endpoint version (e.g. v1beta for cachedContent, v1 for rest)
- Add default project id
- Add default location
"""
vertex_version: Literal["v1", "v1beta1"] = "v1"
if "cachedContent" in requested_route:
vertex_version = "v1beta1"
base_requested_route = "{}/projects/{}/locations/{}".format(
vertex_version, default_vertex_project, default_vertex_location
)
updated_requested_route = "/" + base_requested_route + requested_route
updated_url = new_base_url.copy_with(path=updated_requested_route)
return updated_url
@router.api_route(
"/vertex-ai/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["Vertex AI Pass-through", "pass-through"],
include_in_schema=False,
)
@router.api_route(
"/vertex_ai/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["Vertex AI Pass-through", "pass-through"],
)
async def vertex_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
):
"""
Call LiteLLM proxy via Vertex AI SDK.
[Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
"""
encoded_endpoint = httpx.URL(endpoint).path
verbose_proxy_logger.debug("requested endpoint %s", endpoint)
headers: dict = {}
api_key_to_use = get_litellm_virtual_key(request=request)
user_api_key_dict = await user_api_key_auth(
request=request,
api_key=api_key_to_use,
)
vertex_project: Optional[str] = (
VertexPassThroughRouter._get_vertex_project_id_from_url(endpoint)
)
vertex_location: Optional[str] = (
VertexPassThroughRouter._get_vertex_location_from_url(endpoint)
)
vertex_credentials = vertex_pass_through_router.get_vertex_credentials(
project_id=vertex_project,
location=vertex_location,
)
# Use headers from the incoming request if no vertex credentials are found
if vertex_credentials.vertex_project is None:
headers = dict(request.headers) or {}
verbose_proxy_logger.debug(
"default_vertex_config not set, incoming request headers %s", headers
)
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
headers.pop("content-length", None)
headers.pop("host", None)
else:
vertex_project = vertex_credentials.vertex_project
vertex_location = vertex_credentials.vertex_location
vertex_credentials_str = vertex_credentials.vertex_credentials
# Construct base URL for the target endpoint
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
_auth_header, vertex_project = (
await vertex_fine_tuning_apis_instance._ensure_access_token_async(
credentials=vertex_credentials_str,
project_id=vertex_project,
custom_llm_provider="vertex_ai_beta",
)
)
auth_header, _ = vertex_fine_tuning_apis_instance._get_token_and_url(
model="",
auth_header=_auth_header,
gemini_api_key=None,
vertex_credentials=vertex_credentials_str,
vertex_project=vertex_project,
vertex_location=vertex_location,
stream=False,
custom_llm_provider="vertex_ai_beta",
api_base="",
)
headers = {
"Authorization": f"Bearer {auth_header}",
}
request_route = encoded_endpoint
verbose_proxy_logger.debug("request_route %s", request_route)
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
updated_url = construct_target_url(
base_url=base_target_url,
requested_route=encoded_endpoint,
default_vertex_location=vertex_location,
default_vertex_project=vertex_project,
)
# base_url = httpx.URL(base_target_url)
# updated_url = base_url.copy_with(path=encoded_endpoint)
verbose_proxy_logger.debug("updated url %s", updated_url)
## check for streaming
target = str(updated_url)
is_streaming_request = False
if "stream" in str(updated_url):
is_streaming_request = True
target += "?alt=sse"
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=target,
custom_headers=headers,
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request, # type: ignore
)
return received_value
def get_litellm_virtual_key(request: Request) -> str:
"""
Extract and format API key from request headers.
Prioritizes x-litellm-api-key over Authorization header.
Vertex JS SDK uses `Authorization` header, we use `x-litellm-api-key` to pass litellm virtual key
"""
litellm_api_key = request.headers.get("x-litellm-api-key")
if litellm_api_key:
return f"Bearer {litellm_api_key}"
return request.headers.get("Authorization", "")