|
""" |
|
OpenAI Image Variations Handler |
|
""" |
|
|
|
from typing import Callable, Optional |
|
|
|
import httpx |
|
from openai import AsyncOpenAI, OpenAI |
|
|
|
import litellm |
|
from litellm.types.utils import FileTypes, ImageResponse, LlmProviders |
|
from litellm.utils import ProviderConfigManager |
|
|
|
from ...base_llm.image_variations.transformation import BaseImageVariationConfig |
|
from ...custom_httpx.llm_http_handler import LiteLLMLoggingObj |
|
from ..common_utils import OpenAIError |
|
|
|
|
|
class OpenAIImageVariationsHandler:
    """
    Handle OpenAI image-variation requests via ``images.create_variation``.

    The handler builds (or reuses) an OpenAI SDK client and sends the request
    produced by the provider config's ``transform_request_image_variation``.
    It logs the parsed response and returns the provider config's
    ``transform_response_image_variation`` result. Any exception raised along
    the way is re-raised as ``OpenAIError``.
    """

    def get_sync_client(
        self,
        client: Optional[OpenAI],
        init_client_params: dict,
    ) -> OpenAI:
        """
        Return ``client`` when provided, otherwise construct a new ``OpenAI``.

        Args:
            client: An existing synchronous SDK client to reuse, or ``None``.
            init_client_params: Keyword arguments forwarded to ``OpenAI(...)``
                when a new client must be created.
        """
        if client is None:
            return OpenAI(**init_client_params)
        return client

    def get_async_client(
        self, client: Optional[AsyncOpenAI], init_client_params: dict
    ) -> AsyncOpenAI:
        """
        Return ``client`` when provided, otherwise construct a new ``AsyncOpenAI``.

        Args:
            client: An existing asynchronous SDK client to reuse, or ``None``.
            init_client_params: Keyword arguments forwarded to
                ``AsyncOpenAI(...)`` when a new client must be created.
        """
        if client is None:
            return AsyncOpenAI(**init_client_params)
        return client

    @staticmethod
    def _build_init_client_params(
        api_key: str,
        api_base: str,
        http_client,
        timeout: float,
        max_retries: int,
        organization: Optional[str],
    ) -> dict:
        """Build the constructor kwargs shared by the sync and async SDK clients."""
        return {
            "api_key": api_key,
            "base_url": api_base,
            "http_client": http_client,
            "timeout": timeout,
            "max_retries": max_retries,
            "organization": organization,
        }

    @staticmethod
    def _process_raw_response(
        raw_response,
        api_key: str,
        api_base: str,
        headers: Optional[dict],
        logging_obj: LiteLLMLoggingObj,
        provider_config: BaseImageVariationConfig,
        model: Optional[str],
        request_data: dict,
        image: FileTypes,
        optional_params: dict,
        litellm_params: dict,
    ) -> ImageResponse:
        """
        Parse an SDK raw response, log it, and transform it into an ImageResponse.

        ``raw_response`` is the object returned by
        ``images.with_raw_response.create_variation``; its ``parse()`` result is
        dumped to a dict, which is logged via ``logging_obj.post_call`` and then
        passed to ``provider_config.transform_response_image_variation``.
        """
        response_json = raw_response.parse().model_dump()

        logging_obj.post_call(
            api_key=api_key,
            original_response=response_json,
            additional_args={
                "headers": headers,
                "api_base": api_base,
            },
        )

        return provider_config.transform_response_image_variation(
            model=model,
            model_response=ImageResponse(**response_json),
            # NOTE(review): this is a synthetic placeholder (fixed status/URL, no
            # body), not the real HTTP response; the parsed data is already
            # carried in ``model_response``.
            raw_response=httpx.Response(
                status_code=200,
                request=httpx.Request(method="GET", url="https://litellm.ai"),
            ),
            logging_obj=logging_obj,
            request_data=request_data,
            image=image,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=None,
            api_key=api_key,
        )

    @staticmethod
    def _to_openai_error(e: Exception) -> OpenAIError:
        """
        Convert an arbitrary exception into an ``OpenAIError``.

        Uses the exception's ``status_code`` (default 500), ``text`` (default
        ``str(e)``) and ``headers``. When the exception has no headers, it falls
        back to ``e.response.headers`` if a response object is attached.
        """
        status_code = getattr(e, "status_code", 500)
        error_headers = getattr(e, "headers", None)
        error_text = getattr(e, "text", str(e))
        error_response = getattr(e, "response", None)
        if error_headers is None and error_response is not None:
            error_headers = getattr(error_response, "headers", None)
        return OpenAIError(
            status_code=status_code, message=error_text, headers=error_headers
        )

    async def async_image_variations(
        self,
        api_key: str,
        api_base: str,
        organization: Optional[str],
        client: Optional[AsyncOpenAI],
        data: dict,
        headers: dict,
        model: Optional[str],
        timeout: float,
        max_retries: int,
        logging_obj: LiteLLMLoggingObj,
        model_response: ImageResponse,
        optional_params: dict,
        litellm_params: dict,
        image: FileTypes,
        provider_config: BaseImageVariationConfig,
    ) -> ImageResponse:
        """
        Asynchronously create image variations with an ``AsyncOpenAI`` client.

        Args:
            data: Keyword arguments for ``images.create_variation`` (the
                ``"data"`` entry produced by the provider config's request
                transform).
            client: Optional pre-built async client; a new one is created from
                ``api_key``/``api_base``/``timeout``/``max_retries``/
                ``organization`` when ``None``.
            provider_config: Config whose ``transform_response_image_variation``
                produces the returned ``ImageResponse``.

        Raises:
            OpenAIError: wrapping any exception raised while calling the API or
                transforming the response.
        """
        try:
            async_client = self.get_async_client(
                client=client,
                init_client_params=self._build_init_client_params(
                    api_key=api_key,
                    api_base=api_base,
                    # AsyncOpenAI needs an async HTTP session, so use litellm's
                    # async session rather than the synchronous client_session.
                    http_client=litellm.aclient_session,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                ),
            )

            raw_response = await async_client.images.with_raw_response.create_variation(
                **data
            )

            return self._process_raw_response(
                raw_response=raw_response,
                api_key=api_key,
                api_base=api_base,
                headers=headers,
                logging_obj=logging_obj,
                provider_config=provider_config,
                model=model,
                request_data=data,
                image=image,
                optional_params=optional_params,
                litellm_params=litellm_params,
            )
        except Exception as e:
            raise self._to_openai_error(e) from e

    def image_variations(
        self,
        model_response: ImageResponse,
        api_key: str,
        api_base: str,
        model: Optional[str],
        image: FileTypes,
        timeout: float,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        optional_params: dict,
        litellm_params: dict,
        print_verbose: Optional[Callable] = None,
        logger_fn=None,
        client=None,
        organization: Optional[str] = None,
        headers: Optional[dict] = None,
    ) -> ImageResponse:
        """
        Create image variations with OpenAI, synchronously or asynchronously.

        The OpenAI image-variation provider config transforms the request. The
        ``"data"`` entry of the transformed request is then sent to
        ``images.create_variation``.

        If ``litellm_params["async_call"]`` is truthy, this returns the
        un-awaited coroutine from ``async_image_variations``. The caller is
        expected to await it. Otherwise the request runs synchronously here.

        Note:
            ``"max_retries"`` is popped from ``optional_params`` (default 2),
            which mutates the caller's dict.

        Raises:
            OpenAIError: wrapping any exception raised during the synchronous
                path. This includes a missing provider config or an empty
                transformed ``"data"`` payload.
        """
        try:
            provider_config = ProviderConfigManager.get_provider_image_variation_config(
                model=model or "",
                provider=LlmProviders.OPENAI,
            )
            if provider_config is None:
                raise ValueError(
                    f"image variation provider not found: {custom_llm_provider}."
                )

            # Removed before the request transform below; it configures the SDK
            # client's retry count instead.
            max_retries = optional_params.pop("max_retries", 2)

            data = provider_config.transform_request_image_variation(
                model=model,
                image=image,
                optional_params=optional_params,
                headers=headers or {},
            )
            json_data = data.get("data")
            if not json_data:
                raise ValueError(
                    f"data field is required, for openai image variations. Got={data}"
                )

            logging_obj.pre_call(
                input="",
                api_key=api_key,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                    "complete_input_dict": data,
                },
            )

            if litellm_params.get("async_call", False):
                # Hand back the coroutine; it has its own error wrapping.
                return self.async_image_variations(
                    api_base=api_base,
                    data=json_data,
                    headers=headers or {},
                    model_response=model_response,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    model=model,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                    client=client,
                    provider_config=provider_config,
                    image=image,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                )

            sync_client = self.get_sync_client(
                client=client,
                init_client_params=self._build_init_client_params(
                    api_key=api_key,
                    api_base=api_base,
                    http_client=litellm.client_session,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                ),
            )

            raw_response = sync_client.images.with_raw_response.create_variation(
                **json_data
            )

            return self._process_raw_response(
                raw_response=raw_response,
                api_key=api_key,
                api_base=api_base,
                headers=headers,
                logging_obj=logging_obj,
                provider_config=provider_config,
                model=model,
                request_data=json_data,
                image=image,
                optional_params=optional_params,
                litellm_params=litellm_params,
            )
        except Exception as e:
            raise self._to_openai_error(e) from e
|
|