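"""Cloudflare chat completion transformation for litellm.

Maps OpenAI-style parameters onto Cloudflare's chat endpoint, validates the
request environment, and converts Cloudflare responses and streaming chunks
into litellm's ModelResponse / GenericStreamingChunk types.
"""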
import json
import time
from typing import AsyncIterator, Iterator, List, Optional, Union

import httpx

import litellm
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.chat.transformation import (
    BaseConfig,
    BaseLLMException,
    LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
    ChatCompletionToolCallChunk,
    ChatCompletionUsageBlock,
    GenericStreamingChunk,
    ModelResponse,
    Usage,
)


class CloudflareError(BaseLLMException):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(method="POST", url="https://api.cloudflare.com")
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            status_code=status_code,
            message=message,
            request=self.request,
            response=self.response,
        )


class CloudflareChatConfig(BaseConfig):
    max_tokens: Optional[int] = None
    stream: Optional[bool] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        stream: Optional[bool] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        if api_key is None:
            raise ValueError(
                "Missing Cloudflare API Key - A call is being made to cloudflare but no key is set either in the environment variables or via params"
            )
        headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "Authorization": "Bearer " + api_key,
        }
        return headers

    def get_complete_url(
        self,
        api_base: str,
        model: str,
        optional_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        # api_base is expected to already carry the account-scoped path
        # (e.g. ".../ai/run/"), so the model id is appended directly.
        return api_base + model

    def get_supported_openai_params(self, model: str) -> List[str]:
        return [
            "stream",
            "max_tokens",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        supported_openai_params = self.get_supported_openai_params(model=model)
        for param, value in non_default_params.items():
            # translate OpenAI's max_completion_tokens to Cloudflare's max_tokens
            if param == "max_completion_tokens":
                optional_params["max_tokens"] = value
            elif param in supported_openai_params:
                optional_params[param] = value
        return optional_params

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        config = litellm.CloudflareChatConfig.get_config()
        for k, v in config.items():
            if k not in optional_params:
                optional_params[k] = v

        data = {
            "messages": messages,
            **optional_params,
        }
        return data

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: str,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        completion_response = raw_response.json()

        # the generated text is nested under result.response in Cloudflare's payload
        model_response.choices[0].message.content = completion_response["result"][
            "response"
        ]

        prompt_tokens = litellm.utils.get_token_count(messages=messages, model=model)
        completion_tokens = len(
            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
        )

        model_response.created = int(time.time())
        model_response.model = "cloudflare/" + model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)
        return model_response
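    # For reference, a successful non-streaming Cloudflare body is roughly shaped
    # like the sketch below. Only the "result" -> "response" path is read above;
    # the surrounding envelope keys are an assumption, not taken from this module:
    #
    #   {"result": {"response": "...generated text..."}, "success": true}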

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return CloudflareError(
            status_code=status_code,
            message=error_message,
        )

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        return CloudflareChatResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )


class CloudflareChatResponseIterator(BaseModelResponseIterator):
    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        try:
            text = ""
            tool_use: Optional[ChatCompletionToolCallChunk] = None
            is_finished = False
            finish_reason = ""
            usage: Optional[ChatCompletionUsageBlock] = None
            provider_specific_fields = None

            index = int(chunk.get("index", 0))

            # streamed chunks carry the incremental text under the "response" key
            if "response" in chunk:
                text = chunk["response"]

            returned_chunk = GenericStreamingChunk(
                text=text,
                tool_use=tool_use,
                is_finished=is_finished,
                finish_reason=finish_reason,
                usage=usage,
                index=index,
                provider_specific_fields=provider_specific_fields,
            )

            return returned_chunk

        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
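

# Illustrative usage sketch: this config is normally exercised through
# litellm.completion. The model id, account id, and token below are placeholder
# assumptions, not values taken from this module.
if __name__ == "__main__":
    example_response = litellm.completion(
        model="cloudflare/@cf/meta/llama-3.1-8b-instruct",  # placeholder model id
        messages=[{"role": "user", "content": "Hello from Cloudflare"}],
        api_key="<CLOUDFLARE_API_TOKEN>",  # placeholder credential
        api_base="https://api.cloudflare.com/client/v4/accounts/<ACCOUNT_ID>/ai/run/",
        max_tokens=64,
    )
    print(example_response.choices[0].message.content)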