import traceback
from typing import Literal, Optional

import httpx
from fastapi import APIRouter, HTTPException, Request, Response, status

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
    create_pass_through_route,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.passthrough_endpoints.vertex_ai import *

from .vertex_passthrough_router import VertexPassThroughRouter

router = APIRouter()
vertex_pass_through_router = VertexPassThroughRouter()

default_vertex_config: VertexPassThroughCredentials = VertexPassThroughCredentials()


def _get_vertex_env_vars() -> VertexPassThroughCredentials:
    """
    Helper to get the Vertex pass-through config from environment variables

    The following environment variables are used:
    - DEFAULT_VERTEXAI_PROJECT (project id)
    - DEFAULT_VERTEXAI_LOCATION (location)
    - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file)
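
    Example (illustrative values):
        export DEFAULT_VERTEXAI_PROJECT="my-project-123"
        export DEFAULT_VERTEXAI_LOCATION="us-central1"
        export DEFAULT_GOOGLE_APPLICATION_CREDENTIALS="/path/to/service_account.json"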
    """
    return VertexPassThroughCredentials(
        vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"),
        vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"),
        vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"),
    )


def set_default_vertex_config(config: Optional[dict] = None):
    """Sets vertex configuration from provided config and/or environment variables

    Args:
        config (Optional[dict]): Configuration dictionary
        Example: {
            "vertex_project": "my-project-123",
            "vertex_location": "us-central1",
            "vertex_credentials": "os.environ/GOOGLE_CREDS"
        }
    """
    global default_vertex_config

    # Fall back to environment variables when no config is provided
    if config is None:
        default_vertex_config = _get_vertex_env_vars()
        return

    if isinstance(config, dict):
        for key, value in config.items():
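            # Resolve secret references such as "os.environ/GOOGLE_CREDS"
            # through litellm's secret manager before building credentials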
            if isinstance(value, str) and value.startswith("os.environ/"):
                config[key] = litellm.get_secret(value)

    _set_default_vertex_config(VertexPassThroughCredentials(**config))


def _set_default_vertex_config(
    vertex_pass_through_credentials: VertexPassThroughCredentials,
):
    global default_vertex_config
    default_vertex_config = vertex_pass_through_credentials


def exception_handler(e: Exception):
    verbose_proxy_logger.error(
        "litellm.proxy.proxy_server.v1/projects/tuningJobs(): Exception occurred - {}".format(
            str(e)
        )
    )
    verbose_proxy_logger.debug(traceback.format_exc())
    if isinstance(e, HTTPException):
        return ProxyException(
            message=getattr(e, "message", str(e.detail)),
            type=getattr(e, "type", "None"),
            param=getattr(e, "param", "None"),
            code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
        )
    else:
        error_msg = str(e)
        return ProxyException(
            message=getattr(e, "message", error_msg),
            type=getattr(e, "type", "None"),
            param=getattr(e, "param", "None"),
            code=getattr(e, "status_code", 500),
        )


def construct_target_url(
    base_url: str,
    requested_route: str,
    default_vertex_location: Optional[str],
    default_vertex_project: Optional[str],
) -> httpx.URL:
    """
    Allow the user to specify their own project id / location in the request path.

    If missing, fall back to the configured defaults.

    Handles the cachedContent scenario - https://github.com/BerriAI/litellm/issues/5460

    Constructed URL:
    POST https://LOCATION-aiplatform.googleapis.com/{version}/projects/PROJECT_ID/locations/LOCATION/cachedContents
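
    Example (illustrative values):
        construct_target_url(
            base_url="https://us-central1-aiplatform.googleapis.com/",
            requested_route="/publishers/google/models/gemini-pro:generateContent",
            default_vertex_location="us-central1",
            default_vertex_project="my-project",
        )
        # -> https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-pro:generateContent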
    """
    new_base_url = httpx.URL(base_url)
    if "locations" in requested_route:  # contains the target project id + location
        updated_url = new_base_url.copy_with(path=requested_route)
        return updated_url
    """
    - Add endpoint version (e.g. v1beta1 for cachedContent, v1 for the rest)
    - Add default project id
    - Add default location
    """
    vertex_version: Literal["v1", "v1beta1"] = "v1"
    if "cachedContent" in requested_route:
        vertex_version = "v1beta1"

    base_requested_route = "{}/projects/{}/locations/{}".format(
        vertex_version, default_vertex_project, default_vertex_location
    )

    updated_requested_route = "/" + base_requested_route + requested_route

    updated_url = new_base_url.copy_with(path=updated_requested_route)
    return updated_url


@router.api_route(
    "/vertex-ai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Vertex AI Pass-through", "pass-through"],
    include_in_schema=False,
)
@router.api_route(
    "/vertex_ai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Vertex AI Pass-through", "pass-through"],
)
async def vertex_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
):
    """
    Call Vertex AI endpoints through the LiteLLM proxy (e.g. from the Vertex AI SDK).

    [Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
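
    Example (illustrative):
        curl -X POST \\
          -H "x-litellm-api-key: sk-1234" \\
          -H "Content-Type: application/json" \\
          -d '{"contents": [{"role": "user", "parts": [{"text": "hi"}]}]}' \\
          "http://0.0.0.0:4000/vertex_ai/publishers/google/models/gemini-1.5-flash:generateContent"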
    """
    encoded_endpoint = httpx.URL(endpoint).path
    verbose_proxy_logger.debug("requested endpoint %s", endpoint)
    headers: dict = {}
    api_key_to_use = get_litellm_virtual_key(request=request)
    user_api_key_dict = await user_api_key_auth(
        request=request,
        api_key=api_key_to_use,
    )

    vertex_project: Optional[str] = (
        VertexPassThroughRouter._get_vertex_project_id_from_url(endpoint)
    )
    vertex_location: Optional[str] = (
        VertexPassThroughRouter._get_vertex_location_from_url(endpoint)
    )
    vertex_credentials = vertex_pass_through_router.get_vertex_credentials(
        project_id=vertex_project,
        location=vertex_location,
    )

    # Use headers from the incoming request if no vertex credentials are found
    if vertex_credentials.vertex_project is None:
        headers = dict(request.headers)
        verbose_proxy_logger.debug(
            "default_vertex_config  not set, incoming request headers %s", headers
        )
        base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
        headers.pop("content-length", None)
        headers.pop("host", None)
    else:
        vertex_project = vertex_credentials.vertex_project
        vertex_location = vertex_credentials.vertex_location
        vertex_credentials_str = vertex_credentials.vertex_credentials

        # Construct base URL for the target endpoint
        base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"

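        # Exchange the stored credentials for an OAuth2 access token;
        # this also resolves the project id when it is not set explicitly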
        _auth_header, vertex_project = (
            await vertex_fine_tuning_apis_instance._ensure_access_token_async(
                credentials=vertex_credentials_str,
                project_id=vertex_project,
                custom_llm_provider="vertex_ai_beta",
            )
        )

        auth_header, _ = vertex_fine_tuning_apis_instance._get_token_and_url(
            model="",
            auth_header=_auth_header,
            gemini_api_key=None,
            vertex_credentials=vertex_credentials_str,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            stream=False,
            custom_llm_provider="vertex_ai_beta",
            api_base="",
        )

        headers = {
            "Authorization": f"Bearer {auth_header}",
        }

    request_route = encoded_endpoint
    verbose_proxy_logger.debug("request_route %s", request_route)

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    updated_url = construct_target_url(
        base_url=base_target_url,
        requested_route=encoded_endpoint,
        default_vertex_location=vertex_location,
        default_vertex_project=vertex_project,
    )

    verbose_proxy_logger.debug("updated url %s", updated_url)

    ## check for streaming
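    # Vertex streaming routes (e.g. ":streamGenerateContent") contain "stream";
    # appending "?alt=sse" asks Vertex AI to respond with server-sent events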
    target = str(updated_url)
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True
        target += "?alt=sse"

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=target,
        custom_headers=headers,
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value


def get_litellm_virtual_key(request: Request) -> str:
    """
    Extract and format the API key from request headers.
    Prioritizes `x-litellm-api-key` over the `Authorization` header.

    The Vertex JS SDK sends its own `Authorization` header, so the LiteLLM
    virtual key is passed via `x-litellm-api-key` instead.
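
    Example (illustrative headers):
        x-litellm-api-key: sk-1234
        Authorization: Bearer ya29.<token-from-Vertex-SDK>
    -> returns "Bearer sk-1234"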

    """
    litellm_api_key = request.headers.get("x-litellm-api-key")
    if litellm_api_key:
        return f"Bearer {litellm_api_key}"
    return request.headers.get("Authorization", "")