Raju2024's picture
Upload 1072 files
e3278e4 verified
raw
history blame
8.57 kB
"""
OpenAI Image Variations Handler
"""
from typing import Callable, Optional
import httpx
from openai import AsyncOpenAI, OpenAI
import litellm
from litellm.types.utils import FileTypes, ImageResponse, LlmProviders
from litellm.utils import ProviderConfigManager
from ...base_llm.image_variations.transformation import BaseImageVariationConfig
from ...custom_httpx.llm_http_handler import LiteLLMLoggingObj
from ..common_utils import OpenAIError
class OpenAIImageVariationsHandler:
    """Issue OpenAI image-variation requests, synchronously or asynchronously.

    ``image_variations`` is the entry point: it builds the request through the
    provider config, logs the call, and either performs the request with a
    synchronous ``OpenAI`` client or hands off to ``async_image_variations``
    when ``litellm_params["async_call"]`` is truthy.
    """

    def get_sync_client(
        self,
        client: Optional[OpenAI],
        init_client_params: dict,
    ) -> OpenAI:
        """Return ``client`` if one was supplied, else build a new ``OpenAI``.

        Args:
            client: An existing client to reuse, or ``None``.
            init_client_params: Keyword arguments for ``OpenAI(...)``; only
                used when ``client`` is ``None``.
        """
        if client is not None:
            return client
        return OpenAI(**init_client_params)

    def get_async_client(
        self, client: Optional[AsyncOpenAI], init_client_params: dict
    ) -> AsyncOpenAI:
        """Return ``client`` if one was supplied, else build a new ``AsyncOpenAI``.

        Args:
            client: An existing async client to reuse, or ``None``.
            init_client_params: Keyword arguments for ``AsyncOpenAI(...)``;
                only used when ``client`` is ``None``.
        """
        if client is not None:
            return client
        return AsyncOpenAI(**init_client_params)

    @staticmethod
    def _build_client_params(
        api_key: str,
        api_base: str,
        organization: Optional[str],
        timeout: float,
        max_retries: int,
        http_client,
    ) -> dict:
        """Assemble the keyword arguments used to construct an OpenAI client."""
        return {
            "api_key": api_key,
            "base_url": api_base,
            "http_client": http_client,
            "timeout": timeout,
            "max_retries": max_retries,
            "organization": organization,
        }

    @staticmethod
    def _to_openai_error(e: Exception) -> OpenAIError:
        """Map an arbitrary exception onto ``OpenAIError``.

        Reads ``status_code`` (default 500), ``headers`` and ``text`` (default
        ``str(e)``) from the exception; when the exception carries no headers,
        falls back to the headers of its ``response`` attribute, if any.
        """
        status_code = getattr(e, "status_code", 500)
        error_headers = getattr(e, "headers", None)
        error_text = getattr(e, "text", str(e))
        error_response = getattr(e, "response", None)
        if error_headers is None and error_response:
            error_headers = getattr(error_response, "headers", None)
        return OpenAIError(
            status_code=status_code, message=error_text, headers=error_headers
        )

    @staticmethod
    def _log_and_transform_response(
        *,
        response_json: dict,
        provider_config: BaseImageVariationConfig,
        logging_obj: LiteLLMLoggingObj,
        api_key: str,
        api_base: str,
        headers: Optional[dict],
        model: Optional[str],
        request_data: dict,
        image: FileTypes,
        optional_params: dict,
        litellm_params: dict,
    ) -> ImageResponse:
        """Log the parsed response, then run it through the provider transform."""
        ## LOGGING
        logging_obj.post_call(
            api_key=api_key,
            original_response=response_json,
            additional_args={
                "headers": headers,
                "api_base": api_base,
            },
        )

        ## RESPONSE OBJECT
        return provider_config.transform_response_image_variation(
            model=model,
            model_response=ImageResponse(**response_json),
            raw_response=httpx.Response(
                status_code=200,
                request=httpx.Request(
                    method="GET", url="https://litellm.ai"
                ),  # mock request object
            ),
            logging_obj=logging_obj,
            request_data=request_data,
            image=image,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=None,
            api_key=api_key,
        )

    async def async_image_variations(
        self,
        api_key: str,
        api_base: str,
        organization: Optional[str],
        client: Optional[AsyncOpenAI],
        data: dict,
        headers: dict,
        model: Optional[str],
        timeout: float,
        max_retries: int,
        logging_obj: LiteLLMLoggingObj,
        model_response: ImageResponse,
        optional_params: dict,
        litellm_params: dict,
        image: FileTypes,
        provider_config: BaseImageVariationConfig,
    ) -> ImageResponse:
        """Perform the image-variation request with an async OpenAI client.

        Args:
            data: Keyword arguments forwarded verbatim to
                ``client.images.with_raw_response.create_variation``.
            client: Optional pre-built ``AsyncOpenAI`` client; a new one is
                created from the connection arguments when ``None``.
            model_response: Accepted for interface compatibility; a fresh
                ``ImageResponse`` is built from the API response instead.

        Returns:
            The result of ``provider_config.transform_response_image_variation``.

        Raises:
            OpenAIError: For any exception raised while creating the client,
                calling the API, logging, or transforming the response.
        """
        try:
            # AsyncOpenAI only accepts an httpx.AsyncClient as ``http_client``;
            # litellm.client_session is the synchronous session, so the async
            # path must use litellm.aclient_session.
            init_client_params = self._build_client_params(
                api_key=api_key,
                api_base=api_base,
                organization=organization,
                timeout=timeout,
                max_retries=max_retries,
                http_client=litellm.aclient_session,
            )
            openai_client = self.get_async_client(
                client=client, init_client_params=init_client_params
            )

            raw_response = await openai_client.images.with_raw_response.create_variation(**data)  # type: ignore
            response_json = raw_response.parse().model_dump()

            return self._log_and_transform_response(
                response_json=response_json,
                provider_config=provider_config,
                logging_obj=logging_obj,
                api_key=api_key,
                api_base=api_base,
                headers=headers,
                model=model,
                request_data=data,
                image=image,
                optional_params=optional_params,
                litellm_params=litellm_params,
            )
        except Exception as e:
            raise self._to_openai_error(e) from e

    def image_variations(
        self,
        model_response: ImageResponse,
        api_key: str,
        api_base: str,
        model: Optional[str],
        image: FileTypes,
        timeout: float,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        optional_params: dict,
        litellm_params: dict,
        print_verbose: Optional[Callable] = None,
        logger_fn=None,
        client=None,
        organization: Optional[str] = None,
        headers: Optional[dict] = None,
    ) -> ImageResponse:
        """Create image variations via OpenAI (sync), or dispatch to the async path.

        Args:
            optional_params: Request options. ``"max_retries"`` is popped from
                this dict (default 2), so the caller's dict is mutated.
            litellm_params: When ``litellm_params["async_call"]`` is truthy the
                un-awaited coroutine from ``async_image_variations`` is
                returned instead of an ``ImageResponse``.
            client: Optional pre-built client to reuse; a new one is created
                from the connection arguments when ``None``.
            custom_llm_provider: Only used in the "provider not found" message.
            print_verbose, logger_fn, model_response: Accepted for interface
                compatibility; not used by this method's sync path.

        Returns:
            The result of ``provider_config.transform_response_image_variation``
            (or a coroutine producing it, for async calls).

        Raises:
            OpenAIError: For any failure, including a missing provider config
                or a transformed request without a ``"data"`` field (both are
                raised as ``ValueError`` internally and re-wrapped).
        """
        try:
            provider_config = ProviderConfigManager.get_provider_image_variation_config(
                model=model or "",  # openai defaults to dall-e-2
                provider=LlmProviders.OPENAI,
            )
            if provider_config is None:
                raise ValueError(
                    f"image variation provider not found: {custom_llm_provider}."
                )

            max_retries = optional_params.pop("max_retries", 2)

            data = provider_config.transform_request_image_variation(
                model=model,
                image=image,
                optional_params=optional_params,
                headers=headers or {},
            )
            json_data = data.get("data")
            if not json_data:
                raise ValueError(
                    f"data field is required, for openai image variations. Got={data}"
                )

            ## LOGGING
            logging_obj.pre_call(
                input="",
                api_key=api_key,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                    "complete_input_dict": data,
                },
            )

            if litellm_params.get("async_call", False):
                # Returns the coroutine itself; the caller is expected to await it.
                return self.async_image_variations(
                    api_base=api_base,
                    data=json_data,
                    headers=headers or {},
                    model_response=model_response,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    model=model,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                    client=client,
                    provider_config=provider_config,
                    image=image,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                )  # type: ignore

            init_client_params = self._build_client_params(
                api_key=api_key,
                api_base=api_base,
                organization=organization,
                timeout=timeout,
                max_retries=max_retries,
                http_client=litellm.client_session,
            )
            openai_client = self.get_sync_client(
                client=client, init_client_params=init_client_params
            )

            raw_response = openai_client.images.with_raw_response.create_variation(**json_data)  # type: ignore
            response_json = raw_response.parse().model_dump()

            return self._log_and_transform_response(
                response_json=response_json,
                provider_config=provider_config,
                logging_obj=logging_obj,
                api_key=api_key,
                api_base=api_base,
                headers=headers,
                model=model,
                request_data=json_data,
                image=image,
                optional_params=optional_params,
                litellm_params=litellm_params,
            )
        except Exception as e:
            raise self._to_openai_error(e) from e