import ast import asyncio import json from base64 import b64encode from datetime import datetime from typing import List, Optional import httpx from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from fastapi.responses import StreamingResponse import litellm from litellm._logging import verbose_proxy_logger from litellm.integrations.custom_logger import CustomLogger from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.proxy._types import ( ConfigFieldInfo, ConfigFieldUpdate, PassThroughEndpointResponse, PassThroughGenericEndpoint, ProxyException, UserAPIKeyAuth, ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.common_utils.http_parsing_utils import _read_request_body from litellm.secret_managers.main import get_secret_str from litellm.types.llms.custom_http import httpxSpecialProvider from .streaming_handler import PassThroughStreamingHandler from .success_handler import PassThroughEndpointLogging from .types import EndpointType, PassthroughStandardLoggingPayload router = APIRouter() pass_through_endpoint_logging = PassThroughEndpointLogging() def get_response_body(response: httpx.Response) -> Optional[dict]: try: return response.json() except Exception: return None async def set_env_variables_in_header(custom_headers: Optional[dict]) -> Optional[dict]: """ checks if any headers on config.yaml are defined as os.environ/COHERE_API_KEY etc only runs for headers defined on config.yaml example header can be {"Authorization": "bearer os.environ/COHERE_API_KEY"} """ if custom_headers is None: return None headers = {} for key, value in custom_headers.items(): # langfuse Api requires base64 encoded headers - it's simpleer to just ask litellm users to set their langfuse public and secret keys # we can then get the b64 encoded keys here if key == "LANGFUSE_PUBLIC_KEY" or key == "LANGFUSE_SECRET_KEY": # langfuse requires b64 encoded headers - we construct that here _langfuse_public_key = custom_headers["LANGFUSE_PUBLIC_KEY"] _langfuse_secret_key = custom_headers["LANGFUSE_SECRET_KEY"] if isinstance( _langfuse_public_key, str ) and _langfuse_public_key.startswith("os.environ/"): _langfuse_public_key = get_secret_str(_langfuse_public_key) if isinstance( _langfuse_secret_key, str ) and _langfuse_secret_key.startswith("os.environ/"): _langfuse_secret_key = get_secret_str(_langfuse_secret_key) headers["Authorization"] = "Basic " + b64encode( f"{_langfuse_public_key}:{_langfuse_secret_key}".encode("utf-8") ).decode("ascii") else: # for all other headers headers[key] = value if isinstance(value, str) and "os.environ/" in value: verbose_proxy_logger.debug( "pass through endpoint - looking up 'os.environ/' variable" ) # get string section that is os.environ/ start_index = value.find("os.environ/") _variable_name = value[start_index:] verbose_proxy_logger.debug( "pass through endpoint - getting secret for variable name: %s", _variable_name, ) _secret_value = get_secret_str(_variable_name) if _secret_value is not None: new_value = value.replace(_variable_name, _secret_value) headers[key] = new_value return headers async def chat_completion_pass_through_endpoint( # noqa: PLR0915 fastapi_response: Response, request: Request, adapter_id: str, user_api_key_dict: UserAPIKeyAuth, ): from litellm.proxy.proxy_server import ( add_litellm_data_to_request, general_settings, get_custom_headers, llm_router, proxy_config, proxy_logging_obj, user_api_base, user_max_tokens, user_model, user_request_timeout, user_temperature, version, ) data = {} try: body = await request.body() body_str = body.decode() try: data = ast.literal_eval(body_str) except Exception: data = json.loads(body_str) data["adapter_id"] = adapter_id verbose_proxy_logger.debug( "Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)), ) data["model"] = ( general_settings.get("completion_model", None) # server default or user_model # model name passed via cli args or data["model"] # default passed in http request ) if user_model: data["model"] = user_model data = await add_litellm_data_to_request( data=data, # type: ignore request=request, general_settings=general_settings, user_api_key_dict=user_api_key_dict, version=version, proxy_config=proxy_config, ) # override with user settings, these are params passed via cli if user_temperature: data["temperature"] = user_temperature if user_request_timeout: data["request_timeout"] = user_request_timeout if user_max_tokens: data["max_tokens"] = user_max_tokens if user_api_base: data["api_base"] = user_api_base ### MODEL ALIAS MAPPING ### # check if model name in model alias map # get the actual model name if data["model"] in litellm.model_alias_map: data["model"] = litellm.model_alias_map[data["model"]] ### CALL HOOKS ### - modify incoming data before calling the model data = await proxy_logging_obj.pre_call_hook( # type: ignore user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion" ) ### ROUTE THE REQUESTs ### router_model_names = llm_router.model_names if llm_router is not None else [] # skip router if user passed their key if "api_key" in data: llm_response = asyncio.create_task(litellm.aadapter_completion(**data)) elif ( llm_router is not None and data["model"] in router_model_names ): # model in router model list llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) elif ( llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias ): # model set in model_group_alias llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) elif ( llm_router is not None and data["model"] in llm_router.deployment_names ): # model in router deployments, calling a specific deployment on the router llm_response = asyncio.create_task( llm_router.aadapter_completion(**data, specific_deployment=True) ) elif ( llm_router is not None and data["model"] in llm_router.get_model_ids() ): # model in router model list llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) elif ( llm_router is not None and data["model"] not in router_model_names and llm_router.default_deployment is not None ): # model in router deployments, calling a specific deployment on the router llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) elif user_model is not None: # `litellm --model ` llm_response = asyncio.create_task(litellm.aadapter_completion(**data)) else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail={ "error": "completion: Invalid model name passed in model=" + data.get("model", "") }, ) # Await the llm_response task response = await llm_response hidden_params = getattr(response, "_hidden_params", {}) or {} model_id = hidden_params.get("model_id", None) or "" cache_key = hidden_params.get("cache_key", None) or "" api_base = hidden_params.get("api_base", None) or "" response_cost = hidden_params.get("response_cost", None) or "" ### ALERTING ### asyncio.create_task( proxy_logging_obj.update_request_status( litellm_call_id=data.get("litellm_call_id", ""), status="success" ) ) verbose_proxy_logger.debug("final response: %s", response) fastapi_response.headers.update( get_custom_headers( user_api_key_dict=user_api_key_dict, model_id=model_id, cache_key=cache_key, api_base=api_base, version=version, response_cost=response_cost, ) ) verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response)) return response except Exception as e: await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data ) verbose_proxy_logger.exception( "litellm.proxy.proxy_server.completion(): Exception occured - {}".format( str(e) ) ) error_msg = f"{str(e)}" raise ProxyException( message=getattr(e, "message", error_msg), type=getattr(e, "type", "None"), param=getattr(e, "param", "None"), code=getattr(e, "status_code", 500), ) def forward_headers_from_request( request: Request, headers: dict, forward_headers: Optional[bool] = False, ): """ Helper to forward headers from original request """ if forward_headers is True: request_headers = dict(request.headers) # Header We Should NOT forward request_headers.pop("content-length", None) request_headers.pop("host", None) # Combine request headers with custom headers headers = {**request_headers, **headers} return headers def get_response_headers( headers: httpx.Headers, litellm_call_id: Optional[str] = None ) -> dict: excluded_headers = {"transfer-encoding", "content-encoding"} return_headers = { key: value for key, value in headers.items() if key.lower() not in excluded_headers } if litellm_call_id: return_headers["x-litellm-call-id"] = litellm_call_id return return_headers def get_endpoint_type(url: str) -> EndpointType: if ("generateContent") in url or ("streamGenerateContent") in url: return EndpointType.VERTEX_AI elif ("api.anthropic.com") in url: return EndpointType.ANTHROPIC return EndpointType.GENERIC async def pass_through_request( # noqa: PLR0915 request: Request, target: str, custom_headers: dict, user_api_key_dict: UserAPIKeyAuth, custom_body: Optional[dict] = None, forward_headers: Optional[bool] = False, query_params: Optional[dict] = None, stream: Optional[bool] = None, ): try: import uuid from litellm.litellm_core_utils.litellm_logging import Logging from litellm.proxy.proxy_server import proxy_logging_obj url = httpx.URL(target) headers = custom_headers headers = forward_headers_from_request( request=request, headers=headers, forward_headers=forward_headers ) endpoint_type: EndpointType = get_endpoint_type(str(url)) _parsed_body = None if custom_body: _parsed_body = custom_body else: _parsed_body = await _read_request_body(request) verbose_proxy_logger.debug( "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format( url, headers, _parsed_body ) ) ### CALL HOOKS ### - modify incoming data / reject request before calling the model _parsed_body = await proxy_logging_obj.pre_call_hook( user_api_key_dict=user_api_key_dict, data=_parsed_body, call_type="pass_through_endpoint", ) async_client_obj = get_async_httpx_client( llm_provider=httpxSpecialProvider.PassThroughEndpoint, params={"timeout": 600}, ) async_client = async_client_obj.client litellm_call_id = str(uuid.uuid4()) # create logging object start_time = datetime.now() logging_obj = Logging( model="unknown", messages=[{"role": "user", "content": json.dumps(_parsed_body)}], stream=False, call_type="pass_through_endpoint", start_time=start_time, litellm_call_id=litellm_call_id, function_id="1245", ) passthrough_logging_payload = PassthroughStandardLoggingPayload( url=str(url), request_body=_parsed_body, ) kwargs = _init_kwargs_for_pass_through_endpoint( user_api_key_dict=user_api_key_dict, _parsed_body=_parsed_body, passthrough_logging_payload=passthrough_logging_payload, litellm_call_id=litellm_call_id, request=request, ) # done for supporting 'parallel_request_limiter.py' with pass-through endpoints logging_obj.update_environment_variables( model="unknown", user="unknown", optional_params={}, litellm_params=kwargs["litellm_params"], call_type="pass_through_endpoint", ) logging_obj.model_call_details["litellm_call_id"] = litellm_call_id # combine url with query params for logging requested_query_params: Optional[dict] = ( query_params or request.query_params.__dict__ ) if requested_query_params == request.query_params.__dict__: requested_query_params = None requested_query_params_str = None if requested_query_params: requested_query_params_str = "&".join( f"{k}={v}" for k, v in requested_query_params.items() ) logging_url = str(url) if requested_query_params_str: if "?" in str(url): logging_url = str(url) + "&" + requested_query_params_str else: logging_url = str(url) + "?" + requested_query_params_str logging_obj.pre_call( input=[{"role": "user", "content": json.dumps(_parsed_body)}], api_key="", additional_args={ "complete_input_dict": _parsed_body, "api_base": str(logging_url), "headers": headers, }, ) if stream: req = async_client.build_request( "POST", url, json=_parsed_body, params=requested_query_params, headers=headers, ) response = await async_client.send(req, stream=stream) try: response.raise_for_status() except httpx.HTTPStatusError as e: raise HTTPException( status_code=e.response.status_code, detail=await e.response.aread() ) return StreamingResponse( PassThroughStreamingHandler.chunk_processor( response=response, request_body=_parsed_body, litellm_logging_obj=logging_obj, endpoint_type=endpoint_type, start_time=start_time, passthrough_success_handler_obj=pass_through_endpoint_logging, url_route=str(url), ), headers=get_response_headers( headers=response.headers, litellm_call_id=litellm_call_id, ), status_code=response.status_code, ) verbose_proxy_logger.debug("request method: {}".format(request.method)) verbose_proxy_logger.debug("request url: {}".format(url)) verbose_proxy_logger.debug("request headers: {}".format(headers)) verbose_proxy_logger.debug( "requested_query_params={}".format(requested_query_params) ) verbose_proxy_logger.debug("request body: {}".format(_parsed_body)) response = await async_client.request( method=request.method, url=url, headers=headers, params=requested_query_params, json=_parsed_body, ) verbose_proxy_logger.debug("response.headers= %s", response.headers) if _is_streaming_response(response) is True: try: response.raise_for_status() except httpx.HTTPStatusError as e: raise HTTPException( status_code=e.response.status_code, detail=await e.response.aread() ) return StreamingResponse( PassThroughStreamingHandler.chunk_processor( response=response, request_body=_parsed_body, litellm_logging_obj=logging_obj, endpoint_type=endpoint_type, start_time=start_time, passthrough_success_handler_obj=pass_through_endpoint_logging, url_route=str(url), ), headers=get_response_headers( headers=response.headers, litellm_call_id=litellm_call_id, ), status_code=response.status_code, ) try: response.raise_for_status() except httpx.HTTPStatusError as e: raise HTTPException( status_code=e.response.status_code, detail=e.response.text ) if response.status_code >= 300: raise HTTPException(status_code=response.status_code, detail=response.text) content = await response.aread() ## LOG SUCCESS response_body: Optional[dict] = get_response_body(response) passthrough_logging_payload["response_body"] = response_body end_time = datetime.now() asyncio.create_task( pass_through_endpoint_logging.pass_through_async_success_handler( httpx_response=response, response_body=response_body, url_route=str(url), result="", start_time=start_time, end_time=end_time, logging_obj=logging_obj, cache_hit=False, **kwargs, ) ) return Response( content=content, status_code=response.status_code, headers=get_response_headers( headers=response.headers, litellm_call_id=litellm_call_id, ), ) except Exception as e: verbose_proxy_logger.exception( "litellm.proxy.proxy_server.pass_through_endpoint(): Exception occured - {}".format( str(e) ) ) if isinstance(e, HTTPException): raise ProxyException( message=getattr(e, "message", str(e.detail)), type=getattr(e, "type", "None"), param=getattr(e, "param", "None"), code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), ) else: error_msg = f"{str(e)}" raise ProxyException( message=getattr(e, "message", error_msg), type=getattr(e, "type", "None"), param=getattr(e, "param", "None"), code=getattr(e, "status_code", 500), ) def _init_kwargs_for_pass_through_endpoint( request: Request, user_api_key_dict: UserAPIKeyAuth, passthrough_logging_payload: PassthroughStandardLoggingPayload, _parsed_body: Optional[dict] = None, litellm_call_id: Optional[str] = None, ) -> dict: _parsed_body = _parsed_body or {} _litellm_metadata: Optional[dict] = _parsed_body.pop("litellm_metadata", None) _metadata = { "user_api_key": user_api_key_dict.api_key, "user_api_key_user_id": user_api_key_dict.user_id, "user_api_key_team_id": user_api_key_dict.team_id, "user_api_key_end_user_id": user_api_key_dict.end_user_id, } if _litellm_metadata: _metadata.update(_litellm_metadata) _metadata = _update_metadata_with_tags_in_header( request=request, metadata=_metadata, ) kwargs = { "litellm_params": { "metadata": _metadata, }, "call_type": "pass_through_endpoint", "litellm_call_id": litellm_call_id, "passthrough_logging_payload": passthrough_logging_payload, } return kwargs def _update_metadata_with_tags_in_header(request: Request, metadata: dict) -> dict: """ If tags are in the request headers, add them to the metadata Used for google and vertex JS SDKs """ _tags = request.headers.get("tags") if _tags: metadata["tags"] = _tags.split(",") return metadata def create_pass_through_route( endpoint, target: str, custom_headers: Optional[dict] = None, _forward_headers: Optional[bool] = False, dependencies: Optional[List] = None, ): # check if target is an adapter.py or a url import uuid from litellm.proxy.utils import get_instance_fn try: if isinstance(target, CustomLogger): adapter = target else: adapter = get_instance_fn(value=target) adapter_id = str(uuid.uuid4()) litellm.adapters = [{"id": adapter_id, "adapter": adapter}] async def endpoint_func( # type: ignore request: Request, fastapi_response: Response, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): return await chat_completion_pass_through_endpoint( fastapi_response=fastapi_response, request=request, adapter_id=adapter_id, user_api_key_dict=user_api_key_dict, ) except Exception: verbose_proxy_logger.debug("Defaulting to target being a url.") async def endpoint_func( # type: ignore request: Request, fastapi_response: Response, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), query_params: Optional[dict] = None, custom_body: Optional[dict] = None, stream: Optional[ bool ] = None, # if pass-through endpoint is a streaming request ): return await pass_through_request( # type: ignore request=request, target=target, custom_headers=custom_headers or {}, user_api_key_dict=user_api_key_dict, forward_headers=_forward_headers, query_params=query_params, stream=stream, custom_body=custom_body, ) return endpoint_func def _is_streaming_response(response: httpx.Response) -> bool: _content_type = response.headers.get("content-type") if _content_type is not None and "text/event-stream" in _content_type: return True return False async def initialize_pass_through_endpoints(pass_through_endpoints: list): verbose_proxy_logger.debug("initializing pass through endpoints") from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes from litellm.proxy.proxy_server import app, premium_user for endpoint in pass_through_endpoints: _target = endpoint.get("target", None) _path = endpoint.get("path", None) _custom_headers = endpoint.get("headers", None) _custom_headers = await set_env_variables_in_header( custom_headers=_custom_headers ) _forward_headers = endpoint.get("forward_headers", None) _auth = endpoint.get("auth", None) _dependencies = None if _auth is not None and str(_auth).lower() == "true": if premium_user is not True: raise ValueError( "Error Setting Authentication on Pass Through Endpoint: {}".format( CommonProxyErrors.not_premium_user.value ) ) _dependencies = [Depends(user_api_key_auth)] LiteLLMRoutes.openai_routes.value.append(_path) if _target is None: continue verbose_proxy_logger.debug( "adding pass through endpoint: %s, dependencies: %s", _path, _dependencies ) app.add_api_route( # type: ignore path=_path, endpoint=create_pass_through_route( # type: ignore _path, _target, _custom_headers, _forward_headers, _dependencies ), methods=["GET", "POST", "PUT", "DELETE", "PATCH"], dependencies=_dependencies, ) verbose_proxy_logger.debug("Added new pass through endpoint: %s", _path) @router.get( "/config/pass_through_endpoint", dependencies=[Depends(user_api_key_auth)], response_model=PassThroughEndpointResponse, ) async def get_pass_through_endpoints( endpoint_id: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ GET configured pass through endpoint. If no endpoint_id given, return all configured endpoints. """ from litellm.proxy.proxy_server import get_config_general_settings ## Get existing pass-through endpoint field value try: response: ConfigFieldInfo = await get_config_general_settings( field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict ) except Exception: return PassThroughEndpointResponse(endpoints=[]) pass_through_endpoint_data: Optional[List] = response.field_value if pass_through_endpoint_data is None: return PassThroughEndpointResponse(endpoints=[]) returned_endpoints = [] if endpoint_id is None: for endpoint in pass_through_endpoint_data: if isinstance(endpoint, dict): returned_endpoints.append(PassThroughGenericEndpoint(**endpoint)) elif isinstance(endpoint, PassThroughGenericEndpoint): returned_endpoints.append(endpoint) elif endpoint_id is not None: for endpoint in pass_through_endpoint_data: _endpoint: Optional[PassThroughGenericEndpoint] = None if isinstance(endpoint, dict): _endpoint = PassThroughGenericEndpoint(**endpoint) elif isinstance(endpoint, PassThroughGenericEndpoint): _endpoint = endpoint if _endpoint is not None and _endpoint.path == endpoint_id: returned_endpoints.append(_endpoint) return PassThroughEndpointResponse(endpoints=returned_endpoints) @router.post( "/config/pass_through_endpoint/{endpoint_id}", dependencies=[Depends(user_api_key_auth)], ) async def update_pass_through_endpoints(request: Request, endpoint_id: str): """ Update a pass-through endpoint """ pass @router.post( "/config/pass_through_endpoint", dependencies=[Depends(user_api_key_auth)], ) async def create_pass_through_endpoints( data: PassThroughGenericEndpoint, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ Create new pass-through endpoint """ from litellm.proxy.proxy_server import ( get_config_general_settings, update_config_general_settings, ) ## Get existing pass-through endpoint field value try: response: ConfigFieldInfo = await get_config_general_settings( field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict ) except Exception: response = ConfigFieldInfo( field_name="pass_through_endpoints", field_value=None ) ## Update field with new endpoint data_dict = data.model_dump() if response.field_value is None: response.field_value = [data_dict] elif isinstance(response.field_value, List): response.field_value.append(data_dict) ## Update db updated_data = ConfigFieldUpdate( field_name="pass_through_endpoints", field_value=response.field_value, config_type="general_settings", ) await update_config_general_settings( data=updated_data, user_api_key_dict=user_api_key_dict ) @router.delete( "/config/pass_through_endpoint", dependencies=[Depends(user_api_key_auth)], response_model=PassThroughEndpointResponse, ) async def delete_pass_through_endpoints( endpoint_id: str, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ Delete a pass-through endpoint Returns - the deleted endpoint """ from litellm.proxy.proxy_server import ( get_config_general_settings, update_config_general_settings, ) ## Get existing pass-through endpoint field value try: response: ConfigFieldInfo = await get_config_general_settings( field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict ) except Exception: response = ConfigFieldInfo( field_name="pass_through_endpoints", field_value=None ) ## Update field by removing endpoint pass_through_endpoint_data: Optional[List] = response.field_value response_obj: Optional[PassThroughGenericEndpoint] = None if response.field_value is None or pass_through_endpoint_data is None: raise HTTPException( status_code=400, detail={"error": "There are no pass-through endpoints setup."}, ) elif isinstance(response.field_value, List): invalid_idx: Optional[int] = None for idx, endpoint in enumerate(pass_through_endpoint_data): _endpoint: Optional[PassThroughGenericEndpoint] = None if isinstance(endpoint, dict): _endpoint = PassThroughGenericEndpoint(**endpoint) elif isinstance(endpoint, PassThroughGenericEndpoint): _endpoint = endpoint if _endpoint is not None and _endpoint.path == endpoint_id: invalid_idx = idx response_obj = _endpoint if invalid_idx is not None: pass_through_endpoint_data.pop(invalid_idx) ## Update db updated_data = ConfigFieldUpdate( field_name="pass_through_endpoints", field_value=pass_through_endpoint_data, config_type="general_settings", ) await update_config_general_settings( data=updated_data, user_api_key_dict=user_api_key_dict ) if response_obj is None: raise HTTPException( status_code=400, detail={ "error": "Endpoint={} was not found in pass-through endpoint list.".format( endpoint_id ) }, ) return PassThroughEndpointResponse(endpoints=[response_obj])