import ast
import asyncio
import json
from base64 import b64encode
from datetime import datetime
from typing import List, Optional

import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.responses import StreamingResponse

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.proxy._types import (
    ConfigFieldInfo,
    ConfigFieldUpdate,
    PassThroughEndpointResponse,
    PassThroughGenericEndpoint,
    ProxyException,
    UserAPIKeyAuth,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.custom_http import httpxSpecialProvider

from .streaming_handler import PassThroughStreamingHandler
from .success_handler import PassThroughEndpointLogging
from .types import EndpointType, PassthroughStandardLoggingPayload

router = APIRouter()

pass_through_endpoint_logging = PassThroughEndpointLogging()


def get_response_body(response: httpx.Response) -> Optional[dict]:
    # return the JSON body if the upstream response is valid JSON, else None
    try:
        return response.json()
    except Exception:
        return None


async def set_env_variables_in_header(custom_headers: Optional[dict]) -> Optional[dict]:
    """
    Checks whether any headers in config.yaml are defined as os.environ/COHERE_API_KEY etc.

    Only runs for headers defined in config.yaml.

    Example header:

    {"Authorization": "bearer os.environ/COHERE_API_KEY"}
    """
    if custom_headers is None:
        return None
    headers = {}
    for key, value in custom_headers.items():
        if key == "LANGFUSE_PUBLIC_KEY" or key == "LANGFUSE_SECRET_KEY":
            # Langfuse uses basic auth, so the public/secret key pair is
            # combined into a single base64-encoded Authorization header
            _langfuse_public_key = custom_headers["LANGFUSE_PUBLIC_KEY"]
            _langfuse_secret_key = custom_headers["LANGFUSE_SECRET_KEY"]
            if isinstance(
                _langfuse_public_key, str
            ) and _langfuse_public_key.startswith("os.environ/"):
                _langfuse_public_key = get_secret_str(_langfuse_public_key)
            if isinstance(
                _langfuse_secret_key, str
            ) and _langfuse_secret_key.startswith("os.environ/"):
                _langfuse_secret_key = get_secret_str(_langfuse_secret_key)
            headers["Authorization"] = "Basic " + b64encode(
                f"{_langfuse_public_key}:{_langfuse_secret_key}".encode("utf-8")
            ).decode("ascii")
        else:
            headers[key] = value
            if isinstance(value, str) and "os.environ/" in value:
                verbose_proxy_logger.debug(
                    "pass through endpoint - looking up 'os.environ/' variable"
                )
                # extract the substring starting at "os.environ/"
                start_index = value.find("os.environ/")
                _variable_name = value[start_index:]

                verbose_proxy_logger.debug(
                    "pass through endpoint - getting secret for variable name: %s",
                    _variable_name,
                )
                _secret_value = get_secret_str(_variable_name)
                if _secret_value is not None:
                    new_value = value.replace(_variable_name, _secret_value)
                    headers[key] = new_value
    return headers


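# A minimal usage sketch for set_env_variables_in_header (assumes COHERE_API_KEY
# is set in the environment; values are illustrative):
#
#     headers = await set_env_variables_in_header(
#         {"Authorization": "bearer os.environ/COHERE_API_KEY"}
#     )
#     # -> {"Authorization": "bearer <resolved COHERE_API_KEY value>"}
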
async def chat_completion_pass_through_endpoint(
    fastapi_response: Response,
    request: Request,
    adapter_id: str,
    user_api_key_dict: UserAPIKeyAuth,
):
    from litellm.proxy.proxy_server import (
        add_litellm_data_to_request,
        general_settings,
        get_custom_headers,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )

    data = {}
    try:
        body = await request.body()
        body_str = body.decode()
        # the body may be a python-literal dict or JSON; try both parsers
        try:
            data = ast.literal_eval(body_str)
        except Exception:
            data = json.loads(body_str)

        data["adapter_id"] = adapter_id

        verbose_proxy_logger.debug(
            "Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
        )
        data["model"] = (
            general_settings.get("completion_model", None)  # server default
            or user_model  # model passed via cli args
            or data["model"]  # model in the http request
        )
        if user_model:
            data["model"] = user_model

        data = await add_litellm_data_to_request(
            data=data,
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_config=proxy_config,
        )

        # override with user settings, if set
        if user_temperature:
            data["temperature"] = user_temperature
        if user_request_timeout:
            data["request_timeout"] = user_request_timeout
        if user_max_tokens:
            data["max_tokens"] = user_max_tokens
        if user_api_base:
            data["api_base"] = user_api_base

        # resolve model aliases to their configured names
        if data["model"] in litellm.model_alias_map:
            data["model"] = litellm.model_alias_map[data["model"]]

        # run any registered pre-call hooks (guardrails, budget checks, etc.)
        data = await proxy_logging_obj.pre_call_hook(
            user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
        )

        # route the request: a user-supplied api_key bypasses the router
        router_model_names = llm_router.model_names if llm_router is not None else []

        if "api_key" in data:
            llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
        elif llm_router is not None and data["model"] in router_model_names:
            llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
        elif (
            llm_router is not None
            and llm_router.model_group_alias is not None
            and data["model"] in llm_router.model_group_alias
        ):
            llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
        elif llm_router is not None and data["model"] in llm_router.deployment_names:
            llm_response = asyncio.create_task(
                llm_router.aadapter_completion(**data, specific_deployment=True)
            )
        elif llm_router is not None and data["model"] in llm_router.get_model_ids():
            llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
        elif (
            llm_router is not None
            and data["model"] not in router_model_names
            and llm_router.default_deployment is not None
        ):
            llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
        elif user_model is not None:
            llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={
                    "error": "completion: Invalid model name passed in model="
                    + data.get("model", "")
                },
            )

        response = await llm_response

        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""
        response_cost = hidden_params.get("response_cost", None) or ""

        # mark the request as successful without blocking the response
        asyncio.create_task(
            proxy_logging_obj.update_request_status(
                litellm_call_id=data.get("litellm_call_id", ""), status="success"
            )
        )

        verbose_proxy_logger.debug("final response: %s", response)

        fastapi_response.headers.update(
            get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                response_cost=response_cost,
            )
        )

        verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response))
        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.completion(): Exception occurred - {}".format(
                str(e)
            )
        )
        error_msg = f"{str(e)}"
        raise ProxyException(
            message=getattr(e, "message", error_msg),
            type=getattr(e, "type", "None"),
            param=getattr(e, "param", "None"),
            code=getattr(e, "status_code", 500),
        )


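# Illustrative request body for this adapter-backed endpoint (a sketch only --
# the exact schema depends on the adapter registered for the route):
#
#     {
#         "model": "my-model",
#         "messages": [{"role": "user", "content": "Hello"}]
#     }
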
def forward_headers_from_request(
    request: Request,
    headers: dict,
    forward_headers: Optional[bool] = False,
):
    """
    Helper to forward headers from the original request.
    """
    if forward_headers is True:
        request_headers = dict(request.headers)

        # drop headers that must be recomputed for the outbound request
        request_headers.pop("content-length", None)
        request_headers.pop("host", None)

        # combine request headers with custom headers; custom headers win on conflict
        headers = {**request_headers, **headers}
    return headers


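# Example merge for forward_headers_from_request (a sketch; header values are
# illustrative). With forward_headers=True, an incoming request carrying
# {"x-trace-id": "abc", "authorization": "Bearer sk-user"} and configured
# custom headers {"authorization": "Bearer sk-upstream"} yields
# {"x-trace-id": "abc", "authorization": "Bearer sk-upstream"} -- the
# configured header overrides the client-supplied one.
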
def get_response_headers(
    headers: httpx.Headers, litellm_call_id: Optional[str] = None
) -> dict:
    # exclude encoding headers: httpx hands back an already-decoded body, so
    # forwarding these would mis-describe the response we return
    excluded_headers = {"transfer-encoding", "content-encoding"}

    return_headers = {
        key: value
        for key, value in headers.items()
        if key.lower() not in excluded_headers
    }
    if litellm_call_id:
        return_headers["x-litellm-call-id"] = litellm_call_id

    return return_headers


def get_endpoint_type(url: str) -> EndpointType:
    if "generateContent" in url or "streamGenerateContent" in url:
        return EndpointType.VERTEX_AI
    elif "api.anthropic.com" in url:
        return EndpointType.ANTHROPIC
    return EndpointType.GENERIC


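# Examples of the URL -> endpoint-type mapping (URLs are illustrative):
#
#     get_endpoint_type(".../models/gemini-pro:generateContent")
#     # -> EndpointType.VERTEX_AI
#     get_endpoint_type("https://api.anthropic.com/v1/messages")
#     # -> EndpointType.ANTHROPIC
#     get_endpoint_type("https://api.cohere.com/v1/rerank")
#     # -> EndpointType.GENERIC
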
async def pass_through_request(
    request: Request,
    target: str,
    custom_headers: dict,
    user_api_key_dict: UserAPIKeyAuth,
    custom_body: Optional[dict] = None,
    forward_headers: Optional[bool] = False,
    query_params: Optional[dict] = None,
    stream: Optional[bool] = None,
):
    try:
        import uuid

        from litellm.litellm_core_utils.litellm_logging import Logging
        from litellm.proxy.proxy_server import proxy_logging_obj

        url = httpx.URL(target)
        headers = custom_headers
        headers = forward_headers_from_request(
            request=request, headers=headers, forward_headers=forward_headers
        )

        endpoint_type: EndpointType = get_endpoint_type(str(url))

        _parsed_body = None
        if custom_body:
            _parsed_body = custom_body
        else:
            _parsed_body = await _read_request_body(request)
        verbose_proxy_logger.debug(
            "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
                url, headers, _parsed_body
            )
        )

        # run any registered pre-call hooks before forwarding the request
        _parsed_body = await proxy_logging_obj.pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            data=_parsed_body,
            call_type="pass_through_endpoint",
        )
        async_client_obj = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.PassThroughEndpoint,
            params={"timeout": 600},
        )
        async_client = async_client_obj.client

        litellm_call_id = str(uuid.uuid4())

        # create a logging object for this call
        start_time = datetime.now()
        logging_obj = Logging(
            model="unknown",
            messages=[{"role": "user", "content": json.dumps(_parsed_body)}],
            stream=False,
            call_type="pass_through_endpoint",
            start_time=start_time,
            litellm_call_id=litellm_call_id,
            function_id="1245",
        )
        passthrough_logging_payload = PassthroughStandardLoggingPayload(
            url=str(url),
            request_body=_parsed_body,
        )
        kwargs = _init_kwargs_for_pass_through_endpoint(
            user_api_key_dict=user_api_key_dict,
            _parsed_body=_parsed_body,
            passthrough_logging_payload=passthrough_logging_payload,
            litellm_call_id=litellm_call_id,
            request=request,
        )

        # attach metadata so downstream logging/rate-limit hooks can see this call
        logging_obj.update_environment_variables(
            model="unknown",
            user="unknown",
            optional_params={},
            litellm_params=kwargs["litellm_params"],
            call_type="pass_through_endpoint",
        )
        logging_obj.model_call_details["litellm_call_id"] = litellm_call_id

        # if no explicit query params were passed in, leave them on the target
        # URL and let httpx forward them untouched
        requested_query_params: Optional[dict] = (
            query_params or request.query_params.__dict__
        )
        if requested_query_params == request.query_params.__dict__:
            requested_query_params = None

        requested_query_params_str = None
        if requested_query_params:
            requested_query_params_str = "&".join(
                f"{k}={v}" for k, v in requested_query_params.items()
            )

        logging_url = str(url)
        if requested_query_params_str:
            if "?" in str(url):
                logging_url = str(url) + "&" + requested_query_params_str
            else:
                logging_url = str(url) + "?" + requested_query_params_str

        logging_obj.pre_call(
            input=[{"role": "user", "content": json.dumps(_parsed_body)}],
            api_key="",
            additional_args={
                "complete_input_dict": _parsed_body,
                "api_base": str(logging_url),
                "headers": headers,
            },
        )
        if stream:
            req = async_client.build_request(
                "POST",
                url,
                json=_parsed_body,
                params=requested_query_params,
                headers=headers,
            )

            response = await async_client.send(req, stream=stream)

            try:
                response.raise_for_status()
            except httpx.HTTPStatusError as e:
                raise HTTPException(
                    status_code=e.response.status_code, detail=await e.response.aread()
                )

            return StreamingResponse(
                PassThroughStreamingHandler.chunk_processor(
                    response=response,
                    request_body=_parsed_body,
                    litellm_logging_obj=logging_obj,
                    endpoint_type=endpoint_type,
                    start_time=start_time,
                    passthrough_success_handler_obj=pass_through_endpoint_logging,
                    url_route=str(url),
                ),
                headers=get_response_headers(
                    headers=response.headers,
                    litellm_call_id=litellm_call_id,
                ),
                status_code=response.status_code,
            )

        verbose_proxy_logger.debug("request method: {}".format(request.method))
        verbose_proxy_logger.debug("request url: {}".format(url))
        verbose_proxy_logger.debug("request headers: {}".format(headers))
        verbose_proxy_logger.debug(
            "requested_query_params={}".format(requested_query_params)
        )
        verbose_proxy_logger.debug("request body: {}".format(_parsed_body))

        response = await async_client.request(
            method=request.method,
            url=url,
            headers=headers,
            params=requested_query_params,
            json=_parsed_body,
        )

        verbose_proxy_logger.debug("response.headers= %s", response.headers)

        # the upstream may stream even when the caller did not ask for it
        if _is_streaming_response(response) is True:
            try:
                response.raise_for_status()
            except httpx.HTTPStatusError as e:
                raise HTTPException(
                    status_code=e.response.status_code, detail=await e.response.aread()
                )

            return StreamingResponse(
                PassThroughStreamingHandler.chunk_processor(
                    response=response,
                    request_body=_parsed_body,
                    litellm_logging_obj=logging_obj,
                    endpoint_type=endpoint_type,
                    start_time=start_time,
                    passthrough_success_handler_obj=pass_through_endpoint_logging,
                    url_route=str(url),
                ),
                headers=get_response_headers(
                    headers=response.headers,
                    litellm_call_id=litellm_call_id,
                ),
                status_code=response.status_code,
            )

        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise HTTPException(
                status_code=e.response.status_code, detail=e.response.text
            )

        if response.status_code >= 300:
            raise HTTPException(status_code=response.status_code, detail=response.text)

        content = await response.aread()

        # log the non-streaming response without blocking the caller
        response_body: Optional[dict] = get_response_body(response)
        passthrough_logging_payload["response_body"] = response_body
        end_time = datetime.now()
        asyncio.create_task(
            pass_through_endpoint_logging.pass_through_async_success_handler(
                httpx_response=response,
                response_body=response_body,
                url_route=str(url),
                result="",
                start_time=start_time,
                end_time=end_time,
                logging_obj=logging_obj,
                cache_hit=False,
                **kwargs,
            )
        )

        return Response(
            content=content,
            status_code=response.status_code,
            headers=get_response_headers(
                headers=response.headers,
                litellm_call_id=litellm_call_id,
            ),
        )
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.pass_through_endpoint(): Exception occurred - {}".format(
                str(e)
            )
        )
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "message", str(e.detail)),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        else:
            error_msg = f"{str(e)}"
            raise ProxyException(
                message=getattr(e, "message", error_msg),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", 500),
            )


def _init_kwargs_for_pass_through_endpoint(
    request: Request,
    user_api_key_dict: UserAPIKeyAuth,
    passthrough_logging_payload: PassthroughStandardLoggingPayload,
    _parsed_body: Optional[dict] = None,
    litellm_call_id: Optional[str] = None,
) -> dict:
    _parsed_body = _parsed_body or {}
    _litellm_metadata: Optional[dict] = _parsed_body.pop("litellm_metadata", None)
    _metadata = {
        "user_api_key": user_api_key_dict.api_key,
        "user_api_key_user_id": user_api_key_dict.user_id,
        "user_api_key_team_id": user_api_key_dict.team_id,
        "user_api_key_end_user_id": user_api_key_dict.end_user_id,
    }
    # caller-supplied litellm_metadata is merged over the key-derived defaults
    if _litellm_metadata:
        _metadata.update(_litellm_metadata)

    _metadata = _update_metadata_with_tags_in_header(
        request=request,
        metadata=_metadata,
    )

    kwargs = {
        "litellm_params": {
            "metadata": _metadata,
        },
        "call_type": "pass_through_endpoint",
        "litellm_call_id": litellm_call_id,
        "passthrough_logging_payload": passthrough_logging_payload,
    }
    return kwargs


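# Sketch of the kwargs _init_kwargs_for_pass_through_endpoint produces
# (values are illustrative):
#
#     {
#         "litellm_params": {"metadata": {"user_api_key": "<hashed key>", ...}},
#         "call_type": "pass_through_endpoint",
#         "litellm_call_id": "<uuid4>",
#         "passthrough_logging_payload": <PassthroughStandardLoggingPayload>,
#     }
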
def _update_metadata_with_tags_in_header(request: Request, metadata: dict) -> dict:
    """
    If tags are present in the request headers, add them to the metadata.

    Used by the Google and Vertex JS SDKs.
    """
    _tags = request.headers.get("tags")
    if _tags:
        metadata["tags"] = _tags.split(",")
    return metadata


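# Example: a request sent with the header `tags: vertex-js-sdk,my-team` (values
# illustrative) yields metadata["tags"] == ["vertex-js-sdk", "my-team"].
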
def create_pass_through_route(
    endpoint,
    target: str,
    custom_headers: Optional[dict] = None,
    _forward_headers: Optional[bool] = False,
    dependencies: Optional[List] = None,
):
    import uuid

    from litellm.proxy.utils import get_instance_fn

    # first, try to interpret `target` as an importable adapter; on failure,
    # fall back to treating it as a plain URL
    try:
        if isinstance(target, CustomLogger):
            adapter = target
        else:
            adapter = get_instance_fn(value=target)
        adapter_id = str(uuid.uuid4())
        litellm.adapters = [{"id": adapter_id, "adapter": adapter}]

        async def endpoint_func(
            request: Request,
            fastapi_response: Response,
            user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
        ):
            return await chat_completion_pass_through_endpoint(
                fastapi_response=fastapi_response,
                request=request,
                adapter_id=adapter_id,
                user_api_key_dict=user_api_key_dict,
            )

    except Exception:
        verbose_proxy_logger.debug("Defaulting to target being a url.")

        async def endpoint_func(
            request: Request,
            fastapi_response: Response,
            user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
            query_params: Optional[dict] = None,
            custom_body: Optional[dict] = None,
            stream: Optional[bool] = None,
        ):
            return await pass_through_request(
                request=request,
                target=target,
                custom_headers=custom_headers or {},
                user_api_key_dict=user_api_key_dict,
                forward_headers=_forward_headers,
                query_params=query_params,
                stream=stream,
                custom_body=custom_body,
            )

    return endpoint_func


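# Usage sketch for create_pass_through_route (hypothetical values) -- this is
# how initialize_pass_through_endpoints below wires a configured endpoint into
# FastAPI:
#
#     route = create_pass_through_route(
#         endpoint="/v1/rerank",
#         target="https://api.cohere.com/v1/rerank",
#         custom_headers={"Authorization": "bearer os.environ/COHERE_API_KEY"},
#     )
#     app.add_api_route(path="/v1/rerank", endpoint=route, methods=["POST"])
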
def _is_streaming_response(response: httpx.Response) -> bool:
    _content_type = response.headers.get("content-type")
    if _content_type is not None and "text/event-stream" in _content_type:
        return True
    return False


async def initialize_pass_through_endpoints(pass_through_endpoints: list):
    verbose_proxy_logger.debug("initializing pass through endpoints")
    from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes
    from litellm.proxy.proxy_server import app, premium_user

    for endpoint in pass_through_endpoints:
        _target = endpoint.get("target", None)
        _path = endpoint.get("path", None)
        _custom_headers = endpoint.get("headers", None)
        _custom_headers = await set_env_variables_in_header(
            custom_headers=_custom_headers
        )
        _forward_headers = endpoint.get("forward_headers", None)
        _auth = endpoint.get("auth", None)
        _dependencies = None
        if _auth is not None and str(_auth).lower() == "true":
            # per-endpoint auth is a premium feature
            if premium_user is not True:
                raise ValueError(
                    "Error Setting Authentication on Pass Through Endpoint: {}".format(
                        CommonProxyErrors.not_premium_user.value
                    )
                )
            _dependencies = [Depends(user_api_key_auth)]
            LiteLLMRoutes.openai_routes.value.append(_path)

        if _target is None:
            continue

        verbose_proxy_logger.debug(
            "adding pass through endpoint: %s, dependencies: %s", _path, _dependencies
        )
        app.add_api_route(
            path=_path,
            endpoint=create_pass_through_route(
                _path, _target, _custom_headers, _forward_headers, _dependencies
            ),
            methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
            dependencies=_dependencies,
        )

        verbose_proxy_logger.debug("Added new pass through endpoint: %s", _path)


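# Shape of the config entries initialize_pass_through_endpoints consumes
# (an illustrative sketch; values are examples only):
#
#     [
#         {
#             "path": "/v1/rerank",
#             "target": "https://api.cohere.com/v1/rerank",
#             "headers": {"Authorization": "bearer os.environ/COHERE_API_KEY"},
#             "forward_headers": True,
#             "auth": "true",
#         }
#     ]
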
@router.get(
    "/config/pass_through_endpoint",
    dependencies=[Depends(user_api_key_auth)],
    response_model=PassThroughEndpointResponse,
)
async def get_pass_through_endpoints(
    endpoint_id: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    GET configured pass-through endpoints.

    If no endpoint_id is given, return all configured endpoints.
    """
    from litellm.proxy.proxy_server import get_config_general_settings

    try:
        response: ConfigFieldInfo = await get_config_general_settings(
            field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
        )
    except Exception:
        return PassThroughEndpointResponse(endpoints=[])

    pass_through_endpoint_data: Optional[List] = response.field_value
    if pass_through_endpoint_data is None:
        return PassThroughEndpointResponse(endpoints=[])

    returned_endpoints = []
    if endpoint_id is None:
        for endpoint in pass_through_endpoint_data:
            if isinstance(endpoint, dict):
                returned_endpoints.append(PassThroughGenericEndpoint(**endpoint))
            elif isinstance(endpoint, PassThroughGenericEndpoint):
                returned_endpoints.append(endpoint)
    elif endpoint_id is not None:
        for endpoint in pass_through_endpoint_data:
            _endpoint: Optional[PassThroughGenericEndpoint] = None
            if isinstance(endpoint, dict):
                _endpoint = PassThroughGenericEndpoint(**endpoint)
            elif isinstance(endpoint, PassThroughGenericEndpoint):
                _endpoint = endpoint

            # endpoints are matched by their configured path
            if _endpoint is not None and _endpoint.path == endpoint_id:
                returned_endpoints.append(_endpoint)

    return PassThroughEndpointResponse(endpoints=returned_endpoints)


@router.post(
    "/config/pass_through_endpoint/{endpoint_id}",
    dependencies=[Depends(user_api_key_auth)],
)
async def update_pass_through_endpoints(request: Request, endpoint_id: str):
    """
    Update a pass-through endpoint (not implemented yet).
    """
    pass


@router.post(
    "/config/pass_through_endpoint",
    dependencies=[Depends(user_api_key_auth)],
)
async def create_pass_through_endpoints(
    data: PassThroughGenericEndpoint,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new pass-through endpoint.
    """
    from litellm.proxy.proxy_server import (
        get_config_general_settings,
        update_config_general_settings,
    )

    # load the existing list of configured endpoints (empty if none yet)
    try:
        response: ConfigFieldInfo = await get_config_general_settings(
            field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
        )
    except Exception:
        response = ConfigFieldInfo(
            field_name="pass_through_endpoints", field_value=None
        )

    # append the new endpoint
    data_dict = data.model_dump()
    if response.field_value is None:
        response.field_value = [data_dict]
    elif isinstance(response.field_value, List):
        response.field_value.append(data_dict)

    # persist the updated list back to general_settings
    updated_data = ConfigFieldUpdate(
        field_name="pass_through_endpoints",
        field_value=response.field_value,
        config_type="general_settings",
    )
    await update_config_general_settings(
        data=updated_data, user_api_key_dict=user_api_key_dict
    )


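# Illustrative request body for POST /config/pass_through_endpoint (values are
# examples only; fields mirror PassThroughGenericEndpoint):
#
#     {
#         "path": "/v1/rerank",
#         "target": "https://api.cohere.com/v1/rerank",
#         "headers": {"Authorization": "bearer os.environ/COHERE_API_KEY"}
#     }
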
@router.delete(
    "/config/pass_through_endpoint",
    dependencies=[Depends(user_api_key_auth)],
    response_model=PassThroughEndpointResponse,
)
async def delete_pass_through_endpoints(
    endpoint_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete a pass-through endpoint.

    Returns - the deleted endpoint
    """
    from litellm.proxy.proxy_server import (
        get_config_general_settings,
        update_config_general_settings,
    )

    # load the existing list of configured endpoints
    try:
        response: ConfigFieldInfo = await get_config_general_settings(
            field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
        )
    except Exception:
        response = ConfigFieldInfo(
            field_name="pass_through_endpoints", field_value=None
        )

    # find and remove the endpoint whose path matches endpoint_id
    pass_through_endpoint_data: Optional[List] = response.field_value
    response_obj: Optional[PassThroughGenericEndpoint] = None
    if response.field_value is None or pass_through_endpoint_data is None:
        raise HTTPException(
            status_code=400,
            detail={"error": "There are no pass-through endpoints setup."},
        )
    elif isinstance(response.field_value, List):
        idx_to_delete: Optional[int] = None
        for idx, endpoint in enumerate(pass_through_endpoint_data):
            _endpoint: Optional[PassThroughGenericEndpoint] = None
            if isinstance(endpoint, dict):
                _endpoint = PassThroughGenericEndpoint(**endpoint)
            elif isinstance(endpoint, PassThroughGenericEndpoint):
                _endpoint = endpoint

            if _endpoint is not None and _endpoint.path == endpoint_id:
                idx_to_delete = idx
                response_obj = _endpoint

        if idx_to_delete is not None:
            pass_through_endpoint_data.pop(idx_to_delete)

    # persist the updated list back to general_settings
    updated_data = ConfigFieldUpdate(
        field_name="pass_through_endpoints",
        field_value=pass_through_endpoint_data,
        config_type="general_settings",
    )
    await update_config_general_settings(
        data=updated_data, user_api_key_dict=user_api_key_dict
    )

    if response_obj is None:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Endpoint={} was not found in pass-through endpoint list.".format(
                    endpoint_id
                )
            },
        )
    return PassThroughEndpointResponse(endpoints=[response_obj])