|
import base64 |
|
from enum import auto, IntEnum |
|
from io import BytesIO |
|
|
|
from pydantic import BaseModel |
|
|
|
|
|
class ImageFormat(IntEnum):
    """Enumerates the ways an image can be represented.

    Members are integers, so they compare and serialize as plain ints.
    The numbers are written out explicitly (they equal what ``auto()``
    assigns) so that reordering members cannot silently change them.
    """

    # ``url`` holds a remote address.
    URL = 1
    # ``url`` holds a path on the local filesystem.
    LOCAL_FILEPATH = 2
    # NOTE(review): presumably an in-memory PIL image; not used in this file.
    PIL_IMAGE = 3
    # The image payload is carried as already-encoded data.
    BYTES = 4
    # NOTE(review): meaning not established in this file; confirm with callers.
    DEFAULT = 5
|
|
|
|
|
class Image(BaseModel):
    """An image reference plus helpers to encode it for chat backends.

    The image is identified by ``url`` (a remote URL or a local path) or
    carried inline as a base64 string, as indicated by ``image_format``.
    """

    # Remote URL or local filesystem path, depending on ``image_format``.
    url: str = ""
    # Image subtype used in the ``data:image/<filetype>`` prefix (e.g. "png").
    filetype: str = ""
    # Current representation of the image; see ``ImageFormat``.
    image_format: ImageFormat = ImageFormat.BYTES
    # Base64-encoded payload; this is what BYTES-format images are read from.
    base64_str: str = ""

    def convert_image_to_base64(self) -> str:
        """Return the image as a base64-encoded string.

        URL and LOCAL_FILEPATH images are loaded, converted to RGBA and
        re-encoded as PNG before base64 encoding. BYTES images are already
        stored encoded, so ``base64_str`` is returned unchanged.

        Raises:
            ValueError: If ``image_format`` is not URL, LOCAL_FILEPATH or BYTES.
            requests.HTTPError: If downloading a URL image returns an error
                status.
        """
        if self.image_format == ImageFormat.BYTES:
            # Nothing to load: the payload is already held in encoded form.
            return self.base64_str

        if self.image_format == ImageFormat.URL:
            # Imported lazily so local-only use does not require ``requests``.
            import requests

            response = requests.get(self.url)
            # Fail here rather than handing an HTTP error page to the decoder.
            response.raise_for_status()
            source = BytesIO(response.content)
        elif self.image_format == ImageFormat.LOCAL_FILEPATH:
            source = self.url
        else:
            raise ValueError(
                f"Cannot convert image with format {self.image_format!r} to base64"
            )

        # Aliased because ``Image`` is also the name of this class.
        from PIL import Image as PILImage

        # ``convert`` returns an independent copy, so the source can be closed.
        with PILImage.open(source) as opened:
            rgba_image = opened.convert("RGBA")

        png_buffer = BytesIO()
        rgba_image.save(png_buffer, format="PNG")

        # b64encode needs the raw bytes, not the BytesIO wrapper.
        return base64.b64encode(png_buffer.getvalue()).decode()

    def to_openai_image_format(self) -> str:
        """Return the image as a string for an OpenAI-style image field.

        URL images are returned as their URL. LOCAL_FILEPATH and BYTES images
        are returned as a ``data:image/<filetype>;base64,<payload>`` URI; for
        LOCAL_FILEPATH the file is encoded first and the result is stored in
        ``base64_str`` and ``filetype``.

        Raises:
            ValueError: If ``image_format`` is none of the supported formats.
        """
        if self.image_format == ImageFormat.URL:
            return self.url

        if self.image_format == ImageFormat.LOCAL_FILEPATH:
            self.base64_str = self.convert_image_to_base64()
            # The conversion always re-encodes as PNG, so label it as such.
            self.filetype = "png"
        elif self.image_format != ImageFormat.BYTES:
            raise ValueError(
                "This file is not valid or not currently supported by the "
                f"OpenAI API: {self.url}"
            )

        return f"data:image/{self.filetype};base64,{self.base64_str}"

    def resize_image_and_return_image_in_bytes(self, image, max_image_size_mb):
        """Downscale ``image`` and return it PNG-encoded in a ``BytesIO``.

        The image is first shrunk (keeping its aspect ratio) so its longest
        edge is at most 1024 pixels. If ``max_image_size_mb`` is truthy, the
        image is then shrunk further until the encoded PNG fits within that
        many mebibytes, or until it cannot shrink any more.

        Args:
            image: A PIL image (anything with ``size``, ``width``, ``height``,
                ``resize`` and ``save``).
            max_image_size_mb: Size budget in MiB; ``None`` or ``0`` disables
                the size-based shrinking.

        Returns:
            A ``("png", buffer)`` tuple, where ``buffer`` is a ``BytesIO``
            positioned at the start of the PNG data.

        Raises:
            ValueError: If the image has a zero-length edge.
        """
        import math

        image_format = "png"
        max_len = 1024

        max_hw, min_hw = max(image.size), min(image.size)
        if min_hw <= 0:
            raise ValueError(
                f"Cannot resize an image with an empty dimension: {image.size}"
            )

        # Only touch images that exceed the limit; this avoids resampling an
        # already-small image by one pixel due to float truncation.
        if max_hw > max_len:
            aspect_ratio = max_hw / min_hw
            # Clamp so extreme aspect ratios never yield a 0-pixel edge.
            shortest_edge = max(1, int(max_len / aspect_ratio))
            longest_edge = min(max_len, int(shortest_edge * aspect_ratio))

            width, height = image.size
            if height > width:
                new_size = (shortest_edge, longest_edge)
            else:
                new_size = (longest_edge, shortest_edge)
            image = image.resize(new_size)

        image_bytes = BytesIO()
        image.save(image_bytes, format="PNG")

        if max_image_size_mb:
            target_size_bytes = max_image_size_mb * 1024 * 1024
            current_size_bytes = image_bytes.tell()

            # PNG size is not proportional to pixel count, so one pass may not
            # be enough; repeat until it fits or the image is a single pixel.
            while current_size_bytes > target_size_bytes and max(image.size) > 1:
                resize_factor = (target_size_bytes / current_size_bytes) ** 0.5
                new_width = max(1, math.floor(image.width * resize_factor))
                new_height = max(1, math.floor(image.height * resize_factor))
                image = image.resize((new_width, new_height))

                image_bytes = BytesIO()
                image.save(image_bytes, format="PNG")
                current_size_bytes = image_bytes.tell()

        # Always rewind so callers can read the buffer from the beginning.
        image_bytes.seek(0)

        return image_format, image_bytes

    def convert_url_to_image_bytes(self, max_image_size_mb):
        """Load the file at ``url``, downscale it and return it base64-encoded.

        Despite its name, this method opens ``url`` as a local file path and
        returns a base64 *string*, not raw bytes. SVG files are rasterized to
        PNG with ``cairosvg`` first.

        Args:
            max_image_size_mb: Size budget in MiB passed on to
                ``resize_image_and_return_image_in_bytes``.

        Returns:
            A ``(filetype, base64_string)`` tuple.
        """
        # Aliased because ``Image`` is also the name of this class.
        from PIL import Image as PILImage

        # Match the extension case-insensitively so ".SVG" is handled too.
        if self.url.lower().endswith(".svg"):
            import cairosvg

            with open(self.url, "rb") as svg_file:
                svg_data = svg_file.read()

            source = BytesIO(cairosvg.svg2png(bytestring=svg_data))
        else:
            source = self.url

        # ``convert`` returns an independent copy, so the file can be closed.
        with PILImage.open(source) as opened:
            pil_image = opened.convert("RGBA")

        image_format, image_bytes = self.resize_image_and_return_image_in_bytes(
            pil_image, max_image_size_mb
        )

        img_base64_str = base64.b64encode(image_bytes.getvalue()).decode()

        return image_format, img_base64_str

    def to_conversation_format(self, max_image_size_mb) -> "Image":
        """Convert this image in place to an inline base64 PNG.

        Loads and downscales the file at ``url``, then stores the result in
        ``base64_str``, sets ``filetype`` accordingly and switches
        ``image_format`` to BYTES.

        Args:
            max_image_size_mb: Size budget in MiB for the encoded image.

        Returns:
            This same instance, to allow chaining.
        """
        filetype, encoded = self.convert_url_to_image_bytes(
            max_image_size_mb=max_image_size_mb
        )

        self.filetype = filetype
        self.image_format = ImageFormat.BYTES
        self.base64_str = encoded

        return self
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: encode a bundled sample picture and print the
    # resulting model as JSON.
    sample = Image(url="fastchat/serve/example_images/fridge.jpg")
    converted = sample.to_conversation_format(max_image_size_mb=5 / 1.5)

    print(converted.model_dump_json())
|
|