|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
import os |
|
import threading |
|
import time |
|
from typing import Literal, Optional |
|
|
|
import litellm |
|
from litellm.utils import ModelResponse |
|
|
|
|
|
class BudgetManager:
    """Track per-user spend against a budget and persist it.

    All state lives in ``self.user_dict``, a mapping of user id to a dict that
    may contain the keys ``total_budget``, ``duration`` (in days),
    ``created_at``, ``last_updated_at``, ``current_cost`` and ``model_cost``.

    The store is read from / written to ``user_cost.json`` in the current
    working directory when ``client_type == "local"``, or exchanged with
    ``<api_base>/get_budget`` and ``<api_base>/set_budget`` when
    ``client_type == "hosted"``. Any other ``client_type`` keeps the data in
    memory only.
    """

    # File used by the "local" client type.
    _LOCAL_STORE = "user_cost.json"

    # Length, in days, of each supported budget period. Note that "monthly"
    # is a fixed 28-day period, not a calendar month.
    _DURATION_IN_DAYS = {"daily": 1, "weekly": 7, "monthly": 28, "yearly": 365}

    # Serializes writes to the local file: every mutation spawns its own saver
    # thread, and unsynchronized writers could interleave their output.
    _save_lock = threading.Lock()

    def __init__(
        self,
        project_name: str,
        client_type: str = "local",
        api_base: Optional[str] = None,
        headers: Optional[dict] = None,
    ):
        """Configure the manager and load any previously stored budgets.

        Args:
            project_name: Project identifier sent to the hosted endpoints.
            client_type: ``"local"`` or ``"hosted"``; selects the storage
                backend used by ``load_data`` and ``save_data``.
            api_base: Base URL of the hosted service. Defaults to
                ``https://api.litellm.ai``.
            headers: HTTP headers for hosted requests. Defaults to a JSON
                ``Content-Type`` header.
        """
        self.client_type = client_type
        self.project_name = project_name
        self.api_base = api_base or "https://api.litellm.ai"
        self.headers = headers or {"Content-Type": "application/json"}

        self.load_data()

    def print_verbose(self, print_statement):
        """Log ``print_statement`` at INFO level when ``litellm.set_verbose`` is set."""
        try:
            if litellm.set_verbose:
                import logging

                logging.info(print_statement)
        except Exception:
            # Deliberately best-effort: diagnostics must never break budgeting.
            pass

    def load_data(self) -> None:
        """Populate ``self.user_dict`` from the configured backend.

        An unrecognized ``client_type`` yields an empty store rather than
        leaving ``user_dict`` undefined. ``user_dict`` is only assigned once
        loading has succeeded, so a failed reload keeps the previous state.
        """
        loaded: dict = {}
        if self.client_type == "local":
            if os.path.isfile(self._LOCAL_STORE):
                with open(self._LOCAL_STORE, "r") as json_file:
                    loaded = json.load(json_file)
            else:
                self.print_verbose("User Dictionary not found!")
            self.print_verbose(f"user dict from local: {loaded}")
        elif self.client_type == "hosted":
            url = self.api_base + "/get_budget"
            data = {"project_name": self.project_name}
            response = litellm.module_level_client.post(
                url, headers=self.headers, json=data
            ).json()
            # An "error" status is treated as "no stored budgets yet".
            if response["status"] != "error":
                loaded = response["data"]
        self.user_dict = loaded

    def create_budget(
        self,
        total_budget: float,
        user: str,
        duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None,
        created_at: Optional[float] = None,
    ) -> dict:
        """Create (or replace) the budget entry for ``user`` and persist it.

        Args:
            total_budget: Spending limit for the user.
            user: User identifier used as the key in ``user_dict``.
            duration: Optional reset period. When omitted the entry only holds
                ``total_budget`` and is never reset automatically.
            created_at: Epoch timestamp of creation. Defaults to the time of
                the call (resolved per call, not once at import).

        Returns:
            The newly stored entry for ``user``.

        Raises:
            ValueError: If ``duration`` is not one of the supported periods.
                Nothing is stored in that case.
        """
        if duration is None:
            entry = {"total_budget": total_budget}
        else:
            # Validate before touching the store so a bad call leaves no
            # half-created entry behind.
            if duration not in self._DURATION_IN_DAYS:
                raise ValueError(
                    """duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
                )
            if created_at is None:
                created_at = time.time()
            entry = {
                "total_budget": total_budget,
                "duration": self._DURATION_IN_DAYS[duration],
                "created_at": created_at,
                "last_updated_at": created_at,
            }

        self.user_dict[user] = entry
        # Persist in both cases; budgets without a duration used to be kept in
        # memory only until some later mutation happened to trigger a save.
        self._save_data_thread()
        return entry

    def projected_cost(self, model: str, messages: list, user: str):
        """Return the user's current cost plus the prompt cost of ``messages``.

        Messages whose ``content`` is ``None`` or missing contribute no text.
        """
        text = "".join(message.get("content") or "" for message in messages)
        prompt_tokens = litellm.token_counter(model=model, text=text)
        prompt_cost, _ = litellm.cost_per_token(
            model=model, prompt_tokens=prompt_tokens, completion_tokens=0
        )
        return prompt_cost + self.user_dict[user].get("current_cost", 0)

    def get_total_budget(self, user: str):
        """Return the ``total_budget`` stored for ``user``."""
        return self.user_dict[user]["total_budget"]

    def update_cost(
        self,
        user: str,
        completion_obj: Optional["ModelResponse"] = None,
        model: Optional[str] = None,
        input_text: Optional[str] = None,
        output_text: Optional[str] = None,
    ):
        """Add the cost of one call to the user's running totals and persist.

        The cost is derived from ``model`` + ``input_text`` + ``output_text``
        when all three are provided (truthy); otherwise from
        ``completion_obj``.

        Returns:
            ``{"user": <updated entry>}``.

        Raises:
            ValueError: If neither the text triple nor ``completion_obj`` is
                supplied.
        """
        if model and input_text and output_text:
            prompt_tokens = litellm.token_counter(
                model=model, messages=[{"role": "user", "content": input_text}]
            )
            completion_tokens = litellm.token_counter(
                model=model, messages=[{"role": "user", "content": output_text}]
            )
            prompt_cost, completion_cost = litellm.cost_per_token(
                model=model,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            )
            cost = prompt_cost + completion_cost
        elif completion_obj:
            cost = litellm.completion_cost(completion_response=completion_obj)
            # Attribute the spend to the model named in the response.
            model = completion_obj["model"]
        else:
            raise ValueError(
                "Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager"
            )

        entry = self.user_dict[user]
        entry["current_cost"] = cost + entry.get("current_cost", 0)
        per_model = entry.setdefault("model_cost", {})
        per_model[model] = cost + per_model.get(model, 0)

        self._save_data_thread()
        return {"user": entry}

    def get_current_cost(self, user):
        """Return the user's accumulated cost, or 0 if none is recorded."""
        return self.user_dict[user].get("current_cost", 0)

    def get_model_cost(self, user):
        """Return the user's per-model cost mapping, or 0 if none is recorded."""
        return self.user_dict[user].get("model_cost", 0)

    def is_valid_user(self, user: str) -> bool:
        """Return whether ``user`` has an entry in the store."""
        return user in self.user_dict

    def get_users(self):
        """Return a list of all user ids in the store."""
        return list(self.user_dict)

    def reset_cost(self, user):
        """Zero the user's accumulated and per-model cost (in memory only).

        Returns:
            ``{"user": <updated entry>}``.
        """
        entry = self.user_dict[user]
        entry["current_cost"] = 0
        entry["model_cost"] = {}
        return {"user": entry}

    def reset_on_duration(self, user: str):
        """Reset the user's cost if their budget period has elapsed.

        Users created without a duration have no reset schedule and are left
        untouched. When a reset happens, ``last_updated_at`` is moved to the
        current time and the store is persisted.
        """
        entry = self.user_dict[user]
        if "duration" not in entry:
            return

        now = time.time()
        # ``duration`` is stored in days.
        period_seconds = entry["duration"] * 24 * 60 * 60
        if now - entry["last_updated_at"] >= period_seconds:
            self.reset_cost(user)
            entry["last_updated_at"] = now
            self._save_data_thread()

    def update_budget_all_users(self):
        """Apply ``reset_on_duration`` to every user that has a duration."""
        for user in self.get_users():
            if "duration" in self.user_dict[user]:
                self.reset_on_duration(user)

    def _save_data_thread(self):
        """Run ``save_data`` on a new background thread."""
        threading.Thread(target=self.save_data).start()

    def save_data(self):
        """Persist ``self.user_dict`` to the configured backend.

        Returns:
            ``{"status": "success"}`` for the local backend, the decoded JSON
            response for the hosted backend, and ``None`` for any other
            ``client_type``.
        """
        if self.client_type == "local":
            with self._save_lock:
                # Serialize before opening: open(..., "w") truncates, so a
                # failed dump must not leave an empty or partial file behind.
                payload = json.dumps(self.user_dict, indent=4)
                with open(self._LOCAL_STORE, "w") as json_file:
                    json_file.write(payload)
            return {"status": "success"}

        if self.client_type == "hosted":
            url = self.api_base + "/set_budget"
            data = {"project_name": self.project_name, "user_dict": self.user_dict}
            response = litellm.module_level_client.post(
                url, headers=self.headers, json=data
            )
            return response.json()

        return None
|
|