from constants import IMAGE_PER_CONVERSATION_LIMIT, DEFAULT_SYSTEM_PREAMBLE_TOKEN_COUNT, VISION_COHERE_MODEL_NAME, VISION_MODEL_TOKEN_LIMIT
from prompt_examples import AYA_VISION_PROMPT_EXAMPLES
import base64
from io import BytesIO
from PIL import Image
import logging
import cohere
import os
import traceback
import random
import gradio as gr

# from dotenv import load_dotenv
# load_dotenv()

MULTIMODAL_API_KEY = os.getenv("AYA_VISION_API_KEY")
logger = logging.getLogger(__name__)

aya_vision_client = cohere.ClientV2(
    api_key=MULTIMODAL_API_KEY,
    client_name="c4ai-aya-vision-hf-space"
)
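
# Illustrative only (not part of the original app): a fail-fast guard like this could
# surface a missing AYA_VISION_API_KEY secret before the first API call.
# if not MULTIMODAL_API_KEY:
#     raise RuntimeError("AYA_VISION_API_KEY is not set; add it as a Space secret.")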


def cohere_vision_chat(chat_history, model=VISION_COHERE_MODEL_NAME):
    """Send a non-streaming chat request to the Aya Vision model and return the reply text."""
    response = aya_vision_client.chat(
        messages=chat_history,
        model=model,
    )
    return response.message.content[0].text
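
# Illustrative only: cohere_vision_chat expects Cohere V2-style messages, where the image
# is passed as a data URL inside an "image_url" content part. A minimal sketch (the data
# URL below is a placeholder):
# reply = cohere_vision_chat([
#     {
#         "role": "user",
#         "content": [
#             {"type": "text", "text": "Describe this image."},
#             {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<BASE64>"}},
#         ],
#     }
# ])
# print(reply)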


def get_aya_vision_prompt_example(language):
    """Return the (prompt, image) example pair for the given language from AYA_VISION_PROMPT_EXAMPLES."""
    example = AYA_VISION_PROMPT_EXAMPLES[language]
    print("example:", example)
    print("example prompt:", example[0])
    print("example image:", example[1])
    return example[0], example[1]


def get_base64_from_local_file(file_path):
    """Read a local image file and return its contents as a base64-encoded string, or None on failure."""
    try:
        print("loading image")
        with open(file_path, "rb") as image_file:
            base64_image = base64.b64encode(image_file.read()).decode('utf-8')
        print("converted image")
        return base64_image
    except Exception as e:
        logger.error(f"Error converting local image to base64 string: {e}")
        return None
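
# Illustrative sketch (not in the original app): the PIL / BytesIO imports above can be used
# to verify a round trip from the base64 string back to an image when debugging locally.
# "example.jpg" is a placeholder path.
# b64 = get_base64_from_local_file("example.jpg")
# if b64 is not None:
#     img = Image.open(BytesIO(base64.b64decode(b64)))
#     print(img.size, img.mode)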


def get_aya_vision_response(incoming_message, image_filepath, max_size_mb=5):
    """Stream an Aya Vision reply for a user message and a local image file, yielding the partial text."""
    print("incoming message:", incoming_message)
    print("image_filepath:", image_filepath)
    max_size_bytes = max_size_mb * 1024 * 1024

    # Infer the MIME type from the file extension so the data URL is well formed.
    image_ext = image_filepath.lower()
    if image_ext.endswith(".jpg") or image_ext.endswith(".jpeg"):
        image_type = "image/jpeg"
    elif image_ext.endswith(".png"):
        image_type = "image/png"
    elif image_ext.endswith(".webp"):
        image_type = "image/webp"
    elif image_ext.endswith(".gif"):
        image_type = "image/gif"
    else:
        raise gr.Error("Unsupported image format. Please upload a JPEG, PNG, WEBP or GIF image.")

    chat_history = []
    print("converting image to base 64")
    base64_image = get_base64_from_local_file(image_filepath)
    image = f"data:{image_type};base64,{base64_image}"
    print("Image base64:", image[:30])

    # The Cohere API rejects empty messages, so substitute a placeholder.
    if incoming_message == "" or incoming_message is None:
        incoming_message = "."

    chat_history.append(
        {
            "role": "user",
            "content": [
                {"type": "text", "text": incoming_message},
                {"type": "image_url", "image_url": {"url": image}},
            ],
        }
    )

    image_size_bytes = get_base64_image_size(image)
    if image_size_bytes >= max_size_bytes:
        # gr.Error must be raised; otherwise execution would continue past the size check.
        raise gr.Error(f"Please upload an image smaller than {max_size_mb}MB")

    # Stream the response and yield the accumulated text so Gradio can update incrementally.
    res = aya_vision_client.chat_stream(messages=chat_history, model=VISION_COHERE_MODEL_NAME)
    output = ""
    for event in res:
        if event and event.type == "content-delta":
            output += event.delta.message.content.text
            yield output
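
# Illustrative wiring (an assumption, not the original Space layout): because
# get_aya_vision_response is a generator, Gradio streams each yielded partial
# answer into the output component.
# demo = gr.Interface(
#     fn=get_aya_vision_response,
#     inputs=[gr.Textbox(label="Message"), gr.Image(type="filepath", label="Image")],
#     outputs=gr.Textbox(label="Aya Vision"),
# )
# demo.launch()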


def get_base64_image_size(base64_string):
    """Estimate the decoded size in bytes of a base64 string; a data-URL prefix and whitespace are ignored."""
    # Strip the "data:<mime>;base64," prefix if present.
    if ',' in base64_string:
        base64_data = base64_string.split(',', 1)[1]
    else:
        base64_data = base64_string
    base64_data = base64_data.replace('\n', '').replace('\r', '').replace(' ', '')
    # Every 4 base64 characters encode 3 bytes; '=' padding characters carry no data.
    padding = base64_data.count('=')
    size_bytes = (len(base64_data) * 3) // 4 - padding
    return size_bytes
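
# Illustrative sanity check (assumption): the estimate above should match the file size on
# disk to within a couple of bytes. "example.jpg" is a placeholder path.
# b64 = get_base64_from_local_file("example.jpg")
# print(get_base64_image_size(b64), os.path.getsize("example.jpg"))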