|
from litellm.proxy._types import UserAPIKeyAuth |
|
|
|
|
|
async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
    """
    Validate an OAuth2 bearer token against the configured token info endpoint.

    Sends a GET request to ``$OAUTH_TOKEN_INFO_ENDPOINT`` with the header
    ``Authorization: Bearer <token>``. Fields of the JSON response are then
    mapped onto a ``UserAPIKeyAuth``. The response field names can be
    overridden through environment variables:

    - ``OAUTH_USER_ID_FIELD_NAME`` (default ``"sub"``) -> ``user_id``
    - ``OAUTH_USER_ROLE_FIELD_NAME`` (default ``"role"``) -> ``user_role``
    - ``OAUTH_USER_TEAM_ID_FIELD_NAME`` (default ``"team_id"``) -> ``team_id``

    Only available when the proxy's ``premium_user`` flag is ``True``.

    Args:
        token (str): The OAuth2 token to validate.

    Returns:
        UserAPIKeyAuth: An auth object with ``api_key`` set to ``token``.
        ``user_id``, ``team_id`` and ``user_role`` are read from the token
        info response; each is ``None`` if the response lacks that field.

    Raises:
        ValueError: If the caller is not a premium user, the token info
            endpoint is not set, the endpoint returns an error status, the
            response is not a JSON object, or the request otherwise fails.
    """
    import os

    import httpx

    from litellm._logging import verbose_proxy_logger
    from litellm.llms.custom_httpx.http_handler import (
        get_async_httpx_client,
        httpxSpecialProvider,
    )
    from litellm.proxy._types import CommonProxyErrors
    from litellm.proxy.proxy_server import premium_user

    if premium_user is not True:
        # Separator added so the two sentences don't run together in the message.
        raise ValueError(
            "Oauth2 token validation is only available for premium users. "
            + CommonProxyErrors.not_premium_user.value
        )

    # Never write the raw bearer token to the logs. Keep only a short
    # prefix/suffix so log lines can still be correlated; mask short tokens
    # entirely.
    redacted_token = (
        f"{token[:4]}...{token[-4:]}" if token and len(token) > 12 else "****"
    )
    verbose_proxy_logger.debug("Oauth2 token validation for token=%s", redacted_token)

    token_info_endpoint = os.getenv("OAUTH_TOKEN_INFO_ENDPOINT")
    user_id_field_name = os.environ.get("OAUTH_USER_ID_FIELD_NAME", "sub")
    user_role_field_name = os.environ.get("OAUTH_USER_ROLE_FIELD_NAME", "role")
    user_team_id_field_name = os.environ.get("OAUTH_USER_TEAM_ID_FIELD_NAME", "team_id")

    if not token_info_endpoint:
        raise ValueError("OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set")

    client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
    headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

    try:
        response = await client.get(token_info_endpoint, headers=headers)

        # Turn 4xx/5xx responses (e.g. an invalid or expired token) into
        # httpx.HTTPStatusError, handled below.
        response.raise_for_status()

        data = response.json()

        verbose_proxy_logger.debug(
            "Oauth2 token validation for token=%s, response from /token/info=%s",
            redacted_token,
            data,
        )

        # The field lookups below need a JSON object. Fail with a clear message
        # instead of an AttributeError on e.g. a JSON list or string.
        if not isinstance(data, dict):
            raise ValueError(
                f"token info endpoint returned a {type(data).__name__}, "
                "expected a JSON object"
            )

        user_id = data.get(user_id_field_name)
        user_team_id = data.get(user_team_id_field_name)
        user_role = data.get(user_role_field_name)

        return UserAPIKeyAuth(
            api_key=token,
            team_id=user_team_id,
            user_id=user_id,
            user_role=user_role,
        )
    except httpx.HTTPStatusError as e:
        # The endpoint rejected the token (4xx) or failed (5xx).
        raise ValueError(f"Oauth 2.0 Token validation failed: {e}") from e
    except Exception as e:
        # Any other failure (network error, invalid JSON, malformed payload,
        # UserAPIKeyAuth construction error) is normalized to ValueError for
        # callers. The original exception is kept as __cause__.
        raise ValueError(f"An error occurred during token validation: {e}") from e
|
|