aya_expanse / aya_vision_utils.py
shivalikasingh's picture
Update aya_vision_utils.py
cd98abf verified
from constants import IMAGE_PER_CONVERSATION_LIMIT, DEFAULT_SYSTEM_PREAMBLE_TOKEN_COUNT, VISION_COHERE_MODEL_NAME, VISION_MODEL_TOKEN_LIMIT
from prompt_examples import AYA_VISION_PROMPT_EXAMPLES
import base64
from io import BytesIO
from PIL import Image
import logging
import cohere
import os
import traceback
import random
import gradio as gr
# from dotenv import load_dotenv
# load_dotenv()
# API key for the Aya Vision endpoint, taken from the AYA_VISION_API_KEY
# environment variable (None when the variable is not set).
MULTIMODAL_API_KEY = os.environ.get("AYA_VISION_API_KEY")

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Single Cohere v2 client shared by the chat helpers defined below.
aya_vision_client = cohere.ClientV2(
    client_name="c4ai-aya-vision-hf-space",
    api_key=MULTIMODAL_API_KEY,
)
def cohere_vision_chat(chat_history, model=VISION_COHERE_MODEL_NAME):
    """Run one non-streaming chat request and return the reply text.

    Args:
        chat_history: Messages forwarded unchanged as ``messages`` to the
            module-level ``aya_vision_client``.
        model: Model name passed to the client; defaults to
            ``VISION_COHERE_MODEL_NAME``.

    Returns:
        The ``text`` attribute of the first content item in the response
        message.
    """
    reply = aya_vision_client.chat(model=model, messages=chat_history)
    first_item = reply.message.content[0]
    return first_item.text
def get_aya_vision_prompt_example(language):
    """Return the example prompt and example image stored for ``language``.

    The entry is looked up in ``AYA_VISION_PROMPT_EXAMPLES``; its first two
    items are echoed to stdout and returned as a ``(prompt, image)`` tuple.
    A language missing from the table raises ``KeyError``.
    """
    entry = AYA_VISION_PROMPT_EXAMPLES[language]
    print("example:", entry)

    prompt = entry[0]
    print("example prompt:", prompt)

    image = entry[1]
    print("example image:", image)

    return prompt, image
def get_base64_from_local_file(file_path):
    """Read a local file and return its contents encoded as base64 text.

    Args:
        file_path: Path of the file to read (opened in binary mode).

    Returns:
        The base64 encoding of the file's bytes as a ``str``, or ``None``
        when the file cannot be opened or read.
    """
    print("loading image")
    try:
        # Keep the guarded region to the I/O itself; encoding bytes that
        # were read successfully cannot fail.
        with open(file_path, "rb") as image_file:
            raw_bytes = image_file.read()
    except (OSError, TypeError, ValueError):
        # OSError: missing/unreadable file; TypeError/ValueError: a path
        # value that open() rejects. Logged at ERROR with the traceback so
        # the failure is visible under default logging settings instead of
        # disappearing at DEBUG level.
        logger.exception(
            "Error converting local image to base64 string: %s", file_path
        )
        return None

    base64_image = base64.b64encode(raw_bytes).decode("utf-8")
    print("converted image")
    return base64_image
def get_aya_vision_response(incoming_message, image_filepath, max_size_mb=5):
    """Stream the Aya Vision reply to a text prompt paired with a local image.

    This is a generator: nothing runs until it is iterated. Each yielded
    value is the full reply text accumulated so far, not just the newest
    fragment.

    Args:
        incoming_message: User prompt. An empty or ``None`` prompt is
            replaced by ``"."``.
        image_filepath: Path of the image file. Its suffix must be one of
            ``.jpg``, ``.jpeg``, ``.png``, ``.webp`` or ``.gif``
            (case-insensitive).
        max_size_mb: Upper bound, in MiB, on the decoded image size. An
            image whose size is greater than or equal to this is rejected.

    Yields:
        The cumulative response text after each ``content-delta`` event.

    Raises:
        gr.Error: If the suffix is not recognised, the file cannot be
            read, or the image reaches the size limit.
    """
    print("incoming message:", incoming_message)
    print("image_filepath:", image_filepath)

    # MIME type for each group of recognised file suffixes.
    mime_by_suffixes = {
        (".jpg", ".jpeg"): "image/jpeg",
        (".png",): "image/png",
        (".webp",): "image/webp",
        (".gif",): "image/gif",
    }
    lowered_path = image_filepath.lower()
    image_type = next(
        (
            mime
            for suffixes, mime in mime_by_suffixes.items()
            if lowered_path.endswith(suffixes)
        ),
        None,
    )
    if image_type is None:
        # Previously an unknown suffix left image_type unbound and crashed
        # with UnboundLocalError further down.
        raise gr.Error(
            "Unsupported image format. Please upload a JPEG, PNG, WebP or GIF image."
        )

    print("converting image to base 64")
    base64_image = get_base64_from_local_file(image_filepath)
    if base64_image is None:
        # The helper returns None on read failure; without this guard the
        # literal text "None" would be embedded in the data URL.
        raise gr.Error("Could not read the uploaded image. Please try again.")

    image = f"data:{image_type};base64,{base64_image}"
    print("Image base64:", image[:30])

    # Enforce the size limit before contacting the API. The exception has to
    # be raised: merely constructing gr.Error has no effect.
    max_size_bytes = max_size_mb * 1024 * 1024
    if get_base64_image_size(image) >= max_size_bytes:
        raise gr.Error(f"Please upload image with size under {max_size_mb}MB")

    # to prevent Cohere API from throwing error for empty message
    if not incoming_message:
        incoming_message = "."

    chat_history = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": incoming_message},
                {"type": "image_url", "image_url": {"url": image}},
            ],
        }
    ]

    res = aya_vision_client.chat_stream(
        messages=chat_history,
        model=VISION_COHERE_MODEL_NAME,
    )
    output = ""
    for event in res:
        if event and event.type == "content-delta":
            output += event.delta.message.content.text
            yield output
def get_base64_image_size(base64_string):
    """Estimate how many bytes a base64 string decodes to.

    Everything up to and including the first comma (for example a
    ``data:<mime>;base64,`` prefix) is dropped when a comma is present.
    Newlines, carriage returns and spaces are ignored.

    Args:
        base64_string: Base64 text, optionally prefixed as described above.

    Returns:
        ``len(payload) * 3 // 4`` minus the number of ``=`` characters in
        the payload.
    """
    _prefix, comma, remainder = base64_string.partition(",")
    payload = remainder if comma else base64_string

    # Delete the three whitespace characters in a single pass.
    payload = payload.translate({ord("\n"): None, ord("\r"): None, ord(" "): None})

    pad_count = payload.count("=")
    return (len(payload) * 3) // 4 - pad_count