TestLLM / litellm /proxy /proxy_cli.py
Raju2024's picture
Upload 1072 files
e3278e4 verified
raw
history blame
27.7 kB
import importlib
import json
import os
import random
import subprocess
import sys
import urllib.parse as urlparse
import click
from dotenv import load_dotenv
# Put the current working directory on the import path.
# NOTE(review): presumably this is what lets the bare `from proxy_server import ...`
# fallback in run_server resolve when the CLI is launched from inside the
# proxy directory - confirm.
sys.path.append(os.getcwd())
# Not referenced anywhere else in the visible part of this file.
config_filename = "litellm.secrets"
litellm_mode = os.getenv("LITELLM_MODE", "DEV") # "PRODUCTION", "DEV"
if litellm_mode == "DEV":
    # load_dotenv() is only called in DEV mode (the default when LITELLM_MODE is unset).
    load_dotenv()
from enum import Enum
# Module-level placeholder; run_server takes its own `telemetry` parameter,
# which shadows this name inside that function.
telemetry = None
class LiteLLMDatabaseConnectionPool(Enum):
    """Fallback database connection-pool settings.

    ``run_server`` reads these values when a proxy config file is loaded but
    its ``general_settings`` section does not define
    ``database_connection_pool_limit`` / ``database_connection_timeout``.
    """

    # Ends up as the `connection_limit` query parameter on the database URL.
    database_connection_pool_limit = 10
    # Ends up as the `pool_timeout` query parameter on the database URL
    # (presumably seconds - TODO confirm against the Prisma documentation).
    database_connection_pool_timeout = 60
def append_query_params(url, params) -> str:
    """Return ``url`` with ``params`` merged into its query string.

    Query parameters already present in ``url`` are kept, except that a key
    which also appears in ``params`` is replaced by the value from ``params``.

    Args:
        url: The URL to extend (in this file: a database connection URL).
        params: Mapping of query-parameter names to values.

    Returns:
        The re-assembled URL string.
    """
    from litellm._logging import verbose_proxy_logger

    parsed_url = urlparse.urlparse(url)

    # The URLs passed here are database URLs, which carry the password / IAM
    # token in the userinfo part. Mask it so it never reaches the debug log.
    loggable_url = url
    netloc = parsed_url.netloc
    if isinstance(netloc, str) and "@" in netloc:
        userinfo, _, host_part = netloc.rpartition("@")
        username, colon, _password = userinfo.partition(":")
        if colon:
            loggable_url = urlparse.urlunparse(
                parsed_url._replace(netloc=f"{username}:***@{host_part}")
            )
    verbose_proxy_logger.debug(f"url: {loggable_url}")
    verbose_proxy_logger.debug(f"params: {params}")

    parsed_query = urlparse.parse_qs(parsed_url.query)
    parsed_query.update(params)
    # doseq=True: parse_qs yields lists of values, which must be expanded again.
    encoded_query = urlparse.urlencode(parsed_query, doseq=True)
    modified_url = urlparse.urlunparse(parsed_url._replace(query=encoded_query))
    return modified_url  # type: ignore
def run_ollama_serve():
    """Launch ``ollama serve`` in the background with its output discarded.

    Best effort only: if the process cannot be started, a warning is printed
    and the exception is not propagated.
    """
    try:
        subprocess.Popen(
            ["ollama", "serve"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except Exception as e:
        warning = (
            "\n"
            f"LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`"
            "\n"
        )
        print(warning)  # noqa
def is_port_in_use(port):
    """Return True when a TCP connection to ``localhost:port`` succeeds."""
    import socket

    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # connect_ex returns an error code instead of raising; 0 means connected.
        error_code = probe.connect_ex(("localhost", port))
    finally:
        probe.close()
    return error_code == 0
@click.command()
@click.option(
    "--host", default="0.0.0.0", help="Host for the server to listen on.", envvar="HOST"
)
@click.option("--port", default=4000, help="Port to bind the server to.", envvar="PORT")
@click.option(
    "--num_workers",
    default=1,
    help="Number of uvicorn / gunicorn workers to spin up. By default, 1 uvicorn is used.",
    envvar="NUM_WORKERS",
)
@click.option("--api_base", default=None, help="API base URL.")
@click.option(
    "--api_version",
    default="2024-07-01-preview",
    help="For azure - pass in the api version.",
)
@click.option(
    "--model", "-m", default=None, help="The model name to pass to litellm expects"
)
@click.option(
    "--alias",
    default=None,
    help='The alias for the model - use this to give a litellm model name (e.g. "huggingface/codellama/CodeLlama-7b-Instruct-hf") a more user-friendly name ("codellama")',
)
@click.option(
    "--add_key", default=None, help="The model name to pass to litellm expects"
)
@click.option("--headers", default=None, help="headers for the API call")
@click.option("--save", is_flag=True, type=bool, help="Save the model-specific config")
@click.option(
    "--debug",
    default=False,
    is_flag=True,
    type=bool,
    help="To debug the input",
    envvar="DEBUG",
)
@click.option(
    "--detailed_debug",
    default=False,
    is_flag=True,
    type=bool,
    help="To view detailed debug logs",
    envvar="DETAILED_DEBUG",
)
@click.option(
    "--use_queue",
    default=False,
    is_flag=True,
    type=bool,
    help="To use celery workers for async endpoints",
)
@click.option(
    "--temperature", default=None, type=float, help="Set temperature for the model"
)
@click.option(
    "--max_tokens", default=None, type=int, help="Set max tokens for the model"
)
@click.option(
    "--request_timeout",
    default=None,
    type=int,
    help="Set timeout in seconds for completion calls",
)
@click.option("--drop_params", is_flag=True, help="Drop any unmapped params")
@click.option(
    "--add_function_to_prompt",
    is_flag=True,
    help="If function passed but unsupported, pass it as prompt",
)
@click.option(
    "--config",
    "-c",
    default=None,
    help="Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`",
)
@click.option(
    "--max_budget",
    default=None,
    type=float,
    help="Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`",
)
@click.option(
    "--telemetry",
    default=True,
    type=bool,
    help="Helps us know if people are using this feature. Turn this off by doing `--telemetry False`",
)
@click.option(
    "--log_config",
    default=None,
    type=str,
    help="Path to the logging configuration file",
)
@click.option(
    "--version",
    "-v",
    default=False,
    is_flag=True,
    type=bool,
    help="Print LiteLLM version",
)
@click.option(
    "--health",
    flag_value=True,
    help="Make a chat/completions request to all llms in config.yaml",
)
@click.option(
    "--test",
    flag_value=True,
    help="proxy chat completions url to make a test request to",
)
@click.option(
    "--test_async",
    default=False,
    is_flag=True,
    help="Calls async endpoints /queue/requests and /queue/response",
)
@click.option(
    "--iam_token_db_auth",
    default=False,
    is_flag=True,
    help="Connects to RDS DB with IAM token",
)
@click.option(
    "--num_requests",
    default=10,
    type=int,
    help="Number of requests to hit async endpoint with",
)
@click.option(
    "--run_gunicorn",
    default=False,
    is_flag=True,
    help="Starts proxy via gunicorn, instead of uvicorn (better for managing multiple workers)",
)
@click.option(
    "--run_hypercorn",
    default=False,
    is_flag=True,
    help="Starts proxy via hypercorn, instead of uvicorn (supports HTTP/2)",
)
@click.option(
    "--ssl_keyfile_path",
    default=None,
    type=str,
    help="Path to the SSL keyfile. Use this when you want to provide SSL certificate when starting proxy",
    envvar="SSL_KEYFILE_PATH",
)
@click.option(
    "--ssl_certfile_path",
    default=None,
    type=str,
    help="Path to the SSL certfile. Use this when you want to provide SSL certificate when starting proxy",
    envvar="SSL_CERTFILE_PATH",
)
@click.option("--local", is_flag=True, default=False, help="for local debugging")
def run_server(  # noqa: PLR0915
    host,
    port,
    api_base,
    api_version,
    model,
    alias,
    add_key,
    headers,
    save,
    debug,
    detailed_debug,
    temperature,
    max_tokens,
    request_timeout,
    drop_params,
    add_function_to_prompt,
    config,
    max_budget,
    telemetry,
    test,
    local,
    num_workers,
    test_async,
    iam_token_db_auth,
    num_requests,
    use_queue,
    health,
    version,
    run_gunicorn,
    run_hypercorn,
    ssl_keyfile_path,
    ssl_certfile_path,
    log_config,
):
    # CLI entry point for the LiteLLM proxy.
    #
    # NOTE: documented with comments rather than a docstring on purpose - click
    # uses a command's docstring as its `--help` text, so a docstring here
    # would change the CLI output.
    #
    # Depending on the flags, exactly one of these happens:
    #   --version     print the installed litellm version and return
    #   --test_async  fire `num_requests` concurrent requests at the proxy's
    #                 /queue/request endpoint, poll each job, print a summary
    #   --health      GET {host}:{port}/health and pretty-print the JSON
    #   --test        send test chat/completion requests through the openai client
    #   (otherwise)   save the worker config, prepare DATABASE_URL / DIRECT_URL
    #                 and the prisma schema, then serve the app with uvicorn
    #                 (default), gunicorn (--run_gunicorn) or hypercorn
    #                 (--run_hypercorn)
    #
    # Raises ImportError when the proxy extras (uvicorn / gunicorn) are missing.
    if local:
        from proxy_server import (
            KeyManagementSettings,
            ProxyConfig,
            app,
            save_worker_config,
        )
    else:
        try:
            from .proxy_server import (
                KeyManagementSettings,
                ProxyConfig,
                app,
                save_worker_config,
            )
        except ImportError as e:
            if "litellm[proxy]" in str(e):
                # user is missing a proxy dependency, ask them to pip install litellm[proxy]
                raise e
            else:
                # this is just a local/relative import error, user git cloned litellm
                from proxy_server import (
                    KeyManagementSettings,
                    ProxyConfig,
                    app,
                    save_worker_config,
                )
    if version is True:
        # `import importlib` alone does not guarantee the `metadata` submodule is loaded.
        import importlib.metadata

        pkg_version = importlib.metadata.version("litellm")  # type: ignore
        click.echo(f"\nLiteLLM: Current Version = {pkg_version}\n")
        return
    if model and "ollama" in model and api_base is None:
        run_ollama_serve()
    import httpx

    if test_async is True:
        # `import concurrent` alone does not load the `futures` submodule.
        import concurrent.futures
        import time

        api_base = f"http://{host}:{port}"

        def _make_openai_completion():
            # Queue one request, then poll its job URL until it is finished.
            # Returns the job result, or None when polling fails.
            data = {
                "model": "gpt-3.5-turbo",
                "messages": [
                    {"role": "user", "content": "Write a short poem about the moon"}
                ],
            }
            # Target the proxy selected via --host/--port, not a fixed address.
            response = httpx.post(f"{api_base}/queue/request", json=data)
            response = response.json()
            while True:
                try:
                    url = response["url"]
                    polling_url = f"{api_base}{url}"
                    polling_response = httpx.get(polling_url)
                    polling_response = polling_response.json()
                    print("\n RESPONSE FROM POLLING JOB", polling_response)  # noqa
                    status = polling_response["status"]
                    if status == "finished":
                        # Hand the result back so the summary can count a success.
                        return polling_response["result"]
                    print(  # noqa
                        f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}"  # noqa
                    )  # noqa
                    time.sleep(0.5)
                except Exception as e:
                    print("got exception in polling", e)  # noqa
                    return None

        # Number of concurrent calls (you can adjust this)
        concurrent_calls = num_requests
        # List to store the futures of concurrent calls
        futures = []
        start_time = time.time()
        # Make concurrent calls (ThreadPoolExecutor rejects max_workers < 1)
        with concurrent.futures.ThreadPoolExecutor(  # type: ignore
            max_workers=max(concurrent_calls, 1)
        ) as executor:
            for _ in range(concurrent_calls):
                futures.append(executor.submit(_make_openai_completion))
        # Wait for all futures to complete
        concurrent.futures.wait(futures)  # type: ignore
        # Summarize the results
        successful_calls = 0
        failed_calls = 0
        for future in futures:
            if future.done():
                # A worker that raised (e.g. the initial POST failed) counts as
                # a failed call instead of aborting the whole summary.
                if future.exception() is None and future.result() is not None:
                    successful_calls += 1
                else:
                    failed_calls += 1
        end_time = time.time()
        print(f"Elapsed Time: {end_time-start_time}")  # noqa
        print(f"Load test Summary:")  # noqa
        print(f"Total Requests: {concurrent_calls}")  # noqa
        print(f"Successful Calls: {successful_calls}")  # noqa
        print(f"Failed Calls: {failed_calls}")  # noqa
        return
    if health is not False:
        print("\nLiteLLM: Health Testing models in config")  # noqa
        response = httpx.get(url=f"http://{host}:{port}/health")
        print(json.dumps(response.json(), indent=4))  # noqa
        return
    if test is not False:
        request_model = model or "gpt-3.5-turbo"
        click.echo(
            f"\nLiteLLM: Making a test ChatCompletions request to your proxy. Model={request_model}"
        )
        import openai

        if test is True:  # flag value set
            api_base = f"http://{host}:{port}"
        else:
            api_base = test
        client = openai.OpenAI(api_key="My API Key", base_url=api_base)
        response = client.chat.completions.create(
            model=request_model,
            messages=[
                {
                    "role": "user",
                    "content": "this is a test request, write a short poem",
                }
            ],
            max_tokens=256,
        )
        click.echo(f"\nLiteLLM: response from proxy {response}")
        print(  # noqa
            f"\n LiteLLM: Making a test ChatCompletions + streaming r equest to proxy. Model={request_model}"
        )
        response = client.chat.completions.create(
            model=request_model,
            messages=[
                {
                    "role": "user",
                    "content": "this is a test request, write a short poem",
                }
            ],
            stream=True,
        )
        for chunk in response:
            click.echo(f"LiteLLM: streaming response from proxy {chunk}")
        print("\n making completion request to proxy")  # noqa
        response = client.completions.create(
            model=request_model, prompt="this is a test request, write a short poem"
        )
        print(response)  # noqa
        return
    else:
        if headers:
            headers = json.loads(headers)
        save_worker_config(
            model=model,
            alias=alias,
            api_base=api_base,
            api_version=api_version,
            debug=debug,
            detailed_debug=detailed_debug,
            temperature=temperature,
            max_tokens=max_tokens,
            request_timeout=request_timeout,
            max_budget=max_budget,
            telemetry=telemetry,
            drop_params=drop_params,
            add_function_to_prompt=add_function_to_prompt,
            headers=headers,
            save=save,
            config=config,
            use_queue=use_queue,
        )
        try:
            import uvicorn

            if os.name == "nt":
                # gunicorn is not imported on Windows
                pass
            else:
                import gunicorn.app.base
        except Exception as e:
            raise ImportError(
                "uvicorn, gunicorn needs to be imported. Run - `pip install 'litellm[proxy]'`"
            ) from e
        # Pool settings used when no config file overrides them below.
        db_connection_pool_limit = 100
        db_connection_timeout = 60
        general_settings = {}
        ### GET DB TOKEN FOR IAM AUTH ###
        if iam_token_db_auth:
            from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token

            db_host = os.getenv("DATABASE_HOST")
            db_port = os.getenv("DATABASE_PORT")
            db_user = os.getenv("DATABASE_USER")
            db_name = os.getenv("DATABASE_NAME")
            db_schema = os.getenv("DATABASE_SCHEMA")
            token = generate_iam_auth_token(
                db_host=db_host, db_port=db_port, db_user=db_user
            )
            # print(f"token: {token}")
            _db_url = f"postgresql://{db_user}:{token}@{db_host}:{db_port}/{db_name}"
            if db_schema:
                _db_url += f"?schema={db_schema}"
            os.environ["DATABASE_URL"] = _db_url
            os.environ["IAM_TOKEN_DB_AUTH"] = "True"
        ### DECRYPT ENV VAR ###
        from litellm.secret_managers.aws_secret_manager import decrypt_env_var

        if (
            os.getenv("USE_AWS_KMS", None) is not None
            and os.getenv("USE_AWS_KMS") == "True"
        ):
            ## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV
            new_env_var = decrypt_env_var()
            for k, v in new_env_var.items():
                os.environ[k] = v
        if config is not None:
            # Allow user to pass in db url via config
            # read from there and save it to os.env['DATABASE_URL']
            import asyncio

            proxy_config = ProxyConfig()
            _config = asyncio.run(proxy_config.get_config(config_file_path=config))
            ### LITELLM SETTINGS ###
            litellm_settings = _config.get("litellm_settings", None)
            if (
                litellm_settings is not None
                and "json_logs" in litellm_settings
                and litellm_settings["json_logs"] is True
            ):
                import litellm

                litellm.json_logs = True
                litellm._turn_on_json()
            ### GENERAL SETTINGS ###
            general_settings = _config.get("general_settings", {})
            if general_settings is None:
                general_settings = {}
            if general_settings:
                ### LOAD SECRET MANAGER ###
                key_management_system = general_settings.get(
                    "key_management_system", None
                )
                proxy_config.initialize_secret_manager(key_management_system)
            key_management_settings = general_settings.get(
                "key_management_settings", None
            )
            if key_management_settings is not None:
                import litellm

                litellm._key_management_settings = KeyManagementSettings(
                    **key_management_settings
                )
            database_url = general_settings.get("database_url", None)
            if database_url is None:
                # Check if all required variables are provided
                database_host = os.getenv("DATABASE_HOST")
                database_username = os.getenv("DATABASE_USERNAME")
                database_password = os.getenv("DATABASE_PASSWORD")
                database_name = os.getenv("DATABASE_NAME")
                if (
                    database_host
                    and database_username
                    and database_password
                    and database_name
                ):
                    # Construct DATABASE_URL from the provided variables
                    # NOTE(review): username/password are interpolated as-is; values
                    # containing URL-reserved characters must already be percent-encoded.
                    database_url = f"postgresql://{database_username}:{database_password}@{database_host}/{database_name}"
                    os.environ["DATABASE_URL"] = database_url
            db_connection_pool_limit = general_settings.get(
                "database_connection_pool_limit",
                LiteLLMDatabaseConnectionPool.database_connection_pool_limit.value,
            )
            db_connection_timeout = general_settings.get(
                "database_connection_timeout",
                LiteLLMDatabaseConnectionPool.database_connection_pool_timeout.value,
            )
            if database_url and database_url.startswith("os.environ/"):
                original_dir = os.getcwd()
                # set the working directory to where this script is
                sys.path.insert(
                    0, os.path.abspath("../..")
                )  # Adds the parent directory to the system path - for litellm local dev
                import litellm
                from litellm import get_secret_str

                database_url = get_secret_str(database_url, default_value=None)
                os.chdir(original_dir)
            if database_url is not None and isinstance(database_url, str):
                os.environ["DATABASE_URL"] = database_url
        if (
            os.getenv("DATABASE_URL", None) is not None
            or os.getenv("DIRECT_URL", None) is not None
        ):
            try:
                from litellm.secret_managers.main import get_secret

                if os.getenv("DATABASE_URL", None) is not None:
                    ### add connection pool + pool timeout args
                    params = {
                        "connection_limit": db_connection_pool_limit,
                        "pool_timeout": db_connection_timeout,
                    }
                    database_url = get_secret("DATABASE_URL", default_value=None)
                    modified_url = append_query_params(database_url, params)
                    os.environ["DATABASE_URL"] = modified_url
                if os.getenv("DIRECT_URL", None) is not None:
                    ### add connection pool + pool timeout args
                    params = {
                        "connection_limit": db_connection_pool_limit,
                        "pool_timeout": db_connection_timeout,
                    }
                    database_url = os.getenv("DIRECT_URL")
                    modified_url = append_query_params(database_url, params)
                    os.environ["DIRECT_URL"] = modified_url
                ###
                # Probe for the prisma executable; FileNotFoundError means it is not installed.
                subprocess.run(["prisma"], capture_output=True)
                is_prisma_runnable = True
            except FileNotFoundError:
                is_prisma_runnable = False
            if is_prisma_runnable:
                from litellm.proxy.db.check_migration import check_prisma_schema_diff
                from litellm.proxy.db.prisma_client import should_update_schema

                if (
                    should_update_schema(
                        general_settings.get("disable_prisma_schema_update")
                    )
                    is False
                ):
                    check_prisma_schema_diff(db_url=None)
                else:
                    for _ in range(4):
                        # run prisma db push, before starting server
                        # Save the current working directory
                        original_dir = os.getcwd()
                        # set the working directory to where this script is
                        abspath = os.path.abspath(__file__)
                        dname = os.path.dirname(abspath)
                        os.chdir(dname)
                        try:
                            # check=True makes a non-zero exit raise CalledProcessError;
                            # without it the retry branch below could never run.
                            subprocess.run(
                                ["prisma", "db", "push", "--accept-data-loss"],
                                check=True,
                            )
                            break  # Exit the loop if the subprocess succeeds
                        except subprocess.CalledProcessError as e:
                            import time

                            print(f"Error: {e}")  # noqa
                            # back off 1-4 seconds before the next attempt
                            time.sleep(random.randrange(1, 5))
                        finally:
                            os.chdir(original_dir)
            else:
                print(  # noqa
                    f"Unable to connect to DB. DATABASE_URL found in environment, but prisma package not found."  # noqa
                )
        # Only the default port is swapped for a random one when it is already taken.
        if port == 4000 and is_port_in_use(port):
            port = random.randint(1024, 49152)
        import litellm

        if detailed_debug is True:
            litellm._turn_on_debug()
        # DO NOT DELETE - enables global variables to work across files
        from litellm.proxy.proxy_server import app  # noqa

        uvicorn_args = {
            "app": "litellm.proxy.proxy_server:app",
            "host": host,
            "port": port,
        }
        if log_config is not None:
            print(f"Using log_config: {log_config}")  # noqa
            uvicorn_args["log_config"] = log_config
        elif litellm.json_logs:
            print("Using json logs. Setting log_config to None.")  # noqa
            uvicorn_args["log_config"] = None
        if run_gunicorn is False and run_hypercorn is False:
            if ssl_certfile_path is not None and ssl_keyfile_path is not None:
                print(  # noqa
                    f"\033[1;32mLiteLLM Proxy: Using SSL with certfile: {ssl_certfile_path} and keyfile: {ssl_keyfile_path}\033[0m\n"  # noqa
                )
                uvicorn_args["ssl_keyfile"] = ssl_keyfile_path
                uvicorn_args["ssl_certfile"] = ssl_certfile_path
            # NOTE(review): loop="uvloop" is passed unconditionally - confirm this
            # is acceptable on platforms where uvloop is not installable.
            uvicorn.run(
                **uvicorn_args,
                loop="uvloop",
                workers=num_workers,
            )
        elif run_gunicorn is True:
            # Gunicorn Application Class
            class StandaloneApplication(gunicorn.app.base.BaseApplication):
                def __init__(self, app, options=None):
                    self.options = options or {}  # gunicorn options
                    self.application = app  # FastAPI app
                    super().__init__()
                    _endpoint_str = (
                        f"curl --location 'http://0.0.0.0:{port}/chat/completions' \\"
                    )
                    curl_command = (
                        _endpoint_str
                        + """
--header 'Content-Type: application/json' \\
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
]
}'
\n
"""
                    )
                    print()  # noqa
                    print(  # noqa
                        '\033[1;34mLiteLLM: Test your local proxy with: "litellm --test" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n'
                    )
                    print(  # noqa
                        f"\033[1;34mLiteLLM: Curl Command Test for your local proxy\n {curl_command} \033[0m\n"
                    )
                    print(  # noqa
                        "\033[1;34mDocs: https://docs.litellm.ai/docs/simple_proxy\033[0m\n"
                    )  # noqa
                    print(  # noqa
                        f"\033[1;34mSee all Router/Swagger docs on http://0.0.0.0:{port} \033[0m\n"
                    )  # noqa

                def load_config(self):
                    # note: This Loads the gunicorn config - has nothing to do with LiteLLM Proxy config
                    if self.cfg is not None:
                        config = {
                            key: value
                            for key, value in self.options.items()
                            if key in self.cfg.settings and value is not None
                        }
                    else:
                        config = {}
                    for key, value in config.items():
                        if self.cfg is not None:
                            self.cfg.set(key.lower(), value)

                def load(self):
                    # gunicorn app function
                    return self.application

            print(  # noqa
                f"\033[1;32mLiteLLM Proxy: Starting server on {host}:{port} with {num_workers} workers\033[0m\n"  # noqa
            )
            gunicorn_options = {
                "bind": f"{host}:{port}",
                "workers": num_workers,  # default is 1
                "worker_class": "uvicorn.workers.UvicornWorker",
                "preload": True,  # Add the preload flag,
                "accesslog": "-",  # Log to stdout
                "timeout": 600,  # default to very high number, bedrock/anthropic.claude-v2:1 can take 30+ seconds for the 1st chunk to come in
                "access_log_format": '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s',
            }
            if ssl_certfile_path is not None and ssl_keyfile_path is not None:
                print(  # noqa
                    f"\033[1;32mLiteLLM Proxy: Using SSL with certfile: {ssl_certfile_path} and keyfile: {ssl_keyfile_path}\033[0m\n"  # noqa
                )
                gunicorn_options["certfile"] = ssl_certfile_path
                gunicorn_options["keyfile"] = ssl_keyfile_path
            StandaloneApplication(
                app=app, options=gunicorn_options
            ).run()  # Run gunicorn
        elif run_hypercorn is True:
            import asyncio

            from hypercorn.asyncio import serve
            from hypercorn.config import Config

            print(  # noqa
                f"\033[1;32mLiteLLM Proxy: Starting server on {host}:{port} using Hypercorn\033[0m\n"  # noqa
            )  # noqa
            # Separate name so the `config` CLI parameter is not shadowed.
            hypercorn_config = Config()
            hypercorn_config.bind = [f"{host}:{port}"]
            if ssl_certfile_path is not None and ssl_keyfile_path is not None:
                print(  # noqa
                    f"\033[1;32mLiteLLM Proxy: Using SSL with certfile: {ssl_certfile_path} and keyfile: {ssl_keyfile_path}\033[0m\n"  # noqa
                )
                hypercorn_config.certfile = ssl_certfile_path
                hypercorn_config.keyfile = ssl_keyfile_path
            # hypercorn serve raises a type warning when passing a fast api app - even though fast API is a valid type
            asyncio.run(serve(app, hypercorn_config))  # type: ignore
# Run the click command when this file is executed directly as a script.
if __name__ == "__main__":
    run_server()