File size: 2,934 Bytes
e3278e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from litellm.proxy._types import UserAPIKeyAuth


async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
    """
    Validate an OAuth2 token against the configured token info endpoint.

    Sends a GET request to the URL in the ``OAUTH_TOKEN_INFO_ENDPOINT``
    environment variable with the token as a Bearer credential, then builds
    a ``UserAPIKeyAuth`` from the JSON response.

    The response fields that are read can be overridden via environment variables:
    - ``OAUTH_USER_ID_FIELD_NAME`` (default ``"sub"``)
    - ``OAUTH_USER_ROLE_FIELD_NAME`` (default ``"role"``)
    - ``OAUTH_USER_TEAM_ID_FIELD_NAME`` (default ``"team_id"``)

    Args:
    token (str): The OAuth2 token to validate.

    Returns:
    UserAPIKeyAuth: Auth object with ``api_key`` set to the token and
    ``user_id`` / ``user_role`` / ``team_id`` taken from the response
    (``None`` for any field the response does not contain).

    Raises:
    ValueError: If the caller is not a premium user, the token info endpoint
    is not set, the endpoint returns a 4xx/5xx status, or any other error
    occurs while validating the token.
    """
    import os

    import httpx

    # NOTE(review): these imports are function-scoped, presumably to avoid an
    # import cycle with litellm.proxy.proxy_server -- confirm before hoisting.
    # Importing here also means `premium_user` is read at call time.
    from litellm._logging import verbose_proxy_logger
    from litellm.llms.custom_httpx.http_handler import (
        get_async_httpx_client,
        httpxSpecialProvider,
    )
    from litellm.proxy._types import CommonProxyErrors
    from litellm.proxy.proxy_server import premium_user

    if premium_user is not True:
        raise ValueError(
            "Oauth2 token validation is only available for premium users"
            + CommonProxyErrors.not_premium_user.value
        )

    # The bearer token is a credential: never write it to the logs.
    verbose_proxy_logger.debug("Oauth2 token validation started")

    # Get the token info endpoint and response field names from the environment
    token_info_endpoint = os.getenv("OAUTH_TOKEN_INFO_ENDPOINT")
    user_id_field_name = os.environ.get("OAUTH_USER_ID_FIELD_NAME", "sub")
    user_role_field_name = os.environ.get("OAUTH_USER_ROLE_FIELD_NAME", "role")
    user_team_id_field_name = os.environ.get("OAUTH_USER_TEAM_ID_FIELD_NAME", "team_id")

    if not token_info_endpoint:
        raise ValueError("OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set")

    client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
    headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

    try:
        response = await client.get(token_info_endpoint, headers=headers)

        # if it's a bad token we expect it to raise an HTTPStatusError
        response.raise_for_status()

        # If we get here, the request was successful
        data = response.json()

        verbose_proxy_logger.debug(
            "Oauth2 token validation, response from /token/info=%s",
            data,
        )

        # The fields below are read with .get(), so the body must be a JSON
        # object; fail with a clear message instead of an AttributeError.
        if not isinstance(data, dict):
            raise TypeError(
                f"expected a JSON object from the token info endpoint, got {type(data).__name__}"
            )

        # You might want to add additional checks here based on the response
        # For example, checking if the token is expired or has the correct scope
        user_id = data.get(user_id_field_name)
        user_team_id = data.get(user_team_id_field_name)
        user_role = data.get(user_role_field_name)

        return UserAPIKeyAuth(
            api_key=token,
            team_id=user_team_id,
            user_id=user_id,
            user_role=user_role,
        )
    except httpx.HTTPStatusError as e:
        # This will catch any 4xx or 5xx errors
        raise ValueError(f"Oauth 2.0 Token validation failed: {e}") from e
    except Exception as e:
        # This will catch any other errors (network issues, malformed JSON,
        # an invalid response shape, or UserAPIKeyAuth construction failures)
        raise ValueError(f"An error occurred during token validation: {e}") from e