Spaces:
Configuration error
Configuration error
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= | |
import base64 | |
import os | |
import uuid | |
from io import BytesIO | |
from typing import List, Optional | |
from openai import OpenAI | |
from PIL import Image | |
from camel.toolkits import FunctionTool | |
from camel.toolkits.base import BaseToolkit | |
class DalleToolkit(BaseToolkit):
    r"""A class representing a toolkit for image generation using OpenAI's
    DALL-E model.
    """

    def base64_to_image(self, base64_string: str) -> Optional[Image.Image]:
        r"""Converts a base64 encoded string into a PIL Image object.

        The image data is fully decoded inside this method (not lazily), so
        corrupt or truncated payloads are reported here rather than surfacing
        later when the image is used or saved.

        Args:
            base64_string (str): The base64 encoded string of the image.

        Returns:
            Optional[Image.Image]: The PIL Image object or None if conversion
                fails.
        """
        try:
            # Decode the base64 string to get the raw image bytes.
            image_data = base64.b64decode(base64_string)
            # PIL reads from a file-like object, so wrap the bytes in memory.
            image = Image.open(BytesIO(image_data))
            # Image.open only identifies the file from its header; force a
            # full decode so errors in the pixel data are caught right here.
            image.load()
            return image
        except Exception as e:
            print(f"An error occurred while converting base64 to image: {e}")
            return None

    def image_path_to_base64(self, image_path: str) -> str:
        r"""Converts the file path of an image to a Base64 encoded string.

        Args:
            image_path (str): The path to the image file.

        Returns:
            str: A Base64 encoded string representing the content of the image
                file, or an empty string if the file cannot be read.
        """
        try:
            with open(image_path, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode('utf-8')
        except Exception as e:
            print(
                f"An error occurred while converting image path to base64: {e}"
            )
            return ""

    def image_to_base64(self, image: Image.Image) -> str:
        r"""Converts an image into a base64-encoded string.

        The image is serialized in PNG format and the resulting bytes are
        base64-encoded. If encoding fails, the error message is printed and
        an empty string is returned.

        Args:
            image (Image.Image): The image object to be encoded; any image
                that PIL can save in PNG format is supported.

        Returns:
            str: A base64-encoded string of the image, or an empty string if
                encoding fails.
        """
        try:
            with BytesIO() as buffered_image:
                image.save(buffered_image, format="PNG")
                # getvalue() returns the whole buffer regardless of position.
                image_bytes = buffered_image.getvalue()
            return base64.b64encode(image_bytes).decode('utf-8')
        except Exception as e:
            print(f"An error occurred: {e}")
            return ""

    def get_dalle_img(self, prompt: str, image_dir: str = "img") -> str:
        r"""Generate an image using OpenAI's DALL-E model.

        The generated image is saved as a PNG file with a random UUID name in
        the specified directory, which is created if it does not exist.

        Args:
            prompt (str): The text prompt based on which the image is
                generated.
            image_dir (str): The directory to save the generated image.
                Defaults to 'img'.

        Returns:
            str: The path to the saved image.

        Raises:
            ValueError: If the API response contains no base64 image data, or
                the data cannot be decoded into an image.
        """
        dalle_client = OpenAI()
        response = dalle_client.images.generate(
            model="dall-e-3",
            prompt=prompt,
            size="1024x1792",
            quality="standard",
            n=1,  # NOTE: now dall-e-3 only supports n=1
            response_format="b64_json",
        )
        image_b64 = response.data[0].b64_json
        # b64_json may be None per the SDK's typing; fail with a clear message
        # instead of passing None into the decoder.
        if not image_b64:
            raise ValueError("DALL-E response did not contain image data.")
        image = self.base64_to_image(image_b64)
        if image is None:
            raise ValueError("Failed to convert base64 string to image.")

        os.makedirs(image_dir, exist_ok=True)
        # A random UUID file name avoids overwriting earlier images.
        image_path = os.path.join(image_dir, f"{uuid.uuid4()}.png")
        image.save(image_path)
        return image_path

    def get_tools(self) -> List[FunctionTool]:
        r"""Returns a list of FunctionTool objects representing the
        functions in the toolkit.

        Returns:
            List[FunctionTool]: A list of FunctionTool objects
                representing the functions in the toolkit.
        """
        return [FunctionTool(self.get_dalle_img)]