File size: 3,621 Bytes
43811b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd98abf
 
 
43811b3
cd98abf
43811b3
cd98abf
43811b3
cd98abf
43811b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from constants import IMAGE_PER_CONVERSATION_LIMIT, DEFAULT_SYSTEM_PREAMBLE_TOKEN_COUNT, VISION_COHERE_MODEL_NAME, VISION_MODEL_TOKEN_LIMIT
from prompt_examples import AYA_VISION_PROMPT_EXAMPLES
import base64
from io import BytesIO
from PIL import Image
import logging
import cohere
import os
import traceback
import random
import gradio as gr

# from dotenv import load_dotenv
# load_dotenv()

# API key for the vision endpoint, read from the environment (None when unset).
MULTIMODAL_API_KEY = os.environ.get("AYA_VISION_API_KEY")

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Single Cohere v2 client shared by every helper in this module.
aya_vision_client = cohere.ClientV2(
    client_name="c4ai-aya-vision-hf-space",
    api_key=MULTIMODAL_API_KEY,
)

def cohere_vision_chat(chat_history, model=VISION_COHERE_MODEL_NAME):
    """Send a conversation through ``aya_vision_client.chat`` and return the reply text.

    Args:
        chat_history: Messages forwarded unchanged as the ``messages`` argument.
        model: Model name to query; defaults to ``VISION_COHERE_MODEL_NAME``.

    Returns:
        The ``text`` attribute of the first content item of the response message.
    """
    reply = aya_vision_client.chat(model=model, messages=chat_history)
    first_part = reply.message.content[0]
    return first_part.text


def get_aya_vision_prompt_example(language):
    """Return the example prompt and example image stored for ``language``.

    The entry is looked up in ``AYA_VISION_PROMPT_EXAMPLES``; its first two
    items are treated as the prompt and the image respectively. The entry and
    both items are echoed to stdout for debugging.

    Args:
        language: Key used to index ``AYA_VISION_PROMPT_EXAMPLES``.

    Returns:
        A ``(prompt, image)`` tuple taken from items 0 and 1 of the entry.
    """
    entry = AYA_VISION_PROMPT_EXAMPLES[language]
    print("example:", entry)

    prompt = entry[0]
    print("example prompt:", prompt)

    image = entry[1]
    print("example image:", image)

    return prompt, image

def get_base64_from_local_file(file_path):
    """Read a local file and return its contents encoded as a base64 string.

    Args:
        file_path: Path of the file to open in binary mode.

    Returns:
        The base64 encoding of the file's bytes as a ``str``, or ``None`` when
        the file could not be read or encoded.
    """
    print("loading image")
    try:
        with open(file_path, "rb") as image_file:
            raw_bytes = image_file.read()
        base64_image = base64.b64encode(raw_bytes).decode("utf-8")
    except Exception:
        # Best-effort contract: callers receive None on any failure. Log at
        # ERROR level with the traceback so the failure is visible under the
        # default logging configuration instead of being hidden at DEBUG.
        logger.exception("Error converting local image to base64 string")
        return None
    print("converted image")
    return base64_image


def get_aya_vision_response(incoming_message, image_filepath, max_size_mb=5):
    """Stream the vision model's answer to a text prompt about one local image.

    This is a generator: the image is validated, embedded as a base64 data
    URL in a single user message, and sent through
    ``aya_vision_client.chat_stream``. Validation errors surface when the
    generator is first iterated.

    Args:
        incoming_message: The user's text prompt. ``None`` or ``""`` is
            replaced by ``"."`` because the API rejects empty messages.
        image_filepath: Path to the image file. Its extension (case
            insensitive) selects the MIME type; ``.jpg``, ``.jpeg``, ``.png``,
            ``.webp`` and ``.gif`` are recognised.
        max_size_mb: Upper bound, in units of 1024 * 1024 bytes, on the
            decoded image size. Images of this size or larger are rejected.

    Yields:
        The response text accumulated so far, once per ``content-delta``
        stream event.

    Raises:
        gr.Error: If no image path is given, the extension is not supported,
            the file cannot be read, or the image reaches the size limit.
    """
    print("incoming message:", incoming_message)
    print("image_filepath:", image_filepath)

    if not image_filepath:
        raise gr.Error("Please upload an image")

    # Map recognised file suffixes to the MIME type used in the data URL.
    mime_by_suffix = (
        ((".jpg", ".jpeg"), "image/jpeg"),
        ((".png",), "image/png"),
        ((".webp",), "image/webp"),
        ((".gif",), "image/gif"),
    )
    lowered_path = image_filepath.lower()
    image_type = next(
        (mime for suffixes, mime in mime_by_suffix if lowered_path.endswith(suffixes)),
        None,
    )
    if image_type is None:
        # Previously this case fell through and crashed on an unbound local.
        raise gr.Error(
            "Unsupported image format. Please upload a JPEG, PNG, WebP or GIF image"
        )

    print("converting image to base 64")
    base64_image = get_base64_from_local_file(image_filepath)
    if base64_image is None:
        # The helper returns None on failure; do not send "base64,None" to the API.
        raise gr.Error("Could not read the uploaded image")
    image = f"data:{image_type};base64,{base64_image}"
    print("Image base64:", image[:30])

    # Enforce the size limit before calling the API. gr.Error only reaches the
    # user when it is raised, not merely constructed.
    max_size_bytes = max_size_mb * 1024 * 1024
    if get_base64_image_size(image) >= max_size_bytes:
        raise gr.Error(f"Please upload image with size under {max_size_mb}MB")

    # to prevent Cohere API from throwing error for empty message
    if incoming_message is None or incoming_message == "":
        incoming_message = "."

    chat_history = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": incoming_message},
                {"type": "image_url", "image_url": {"url": image}},
            ],
        }
    ]

    stream = aya_vision_client.chat_stream(
        messages=chat_history,
        model=VISION_COHERE_MODEL_NAME,
    )

    output = ""
    for event in stream:
        # Only content deltas carry text; other event types are skipped.
        if event and event.type == "content-delta":
            output += event.delta.message.content.text
            yield output

def get_base64_image_size(base64_string):
    """Estimate the decoded byte size of a base64 string or data URL.

    Args:
        base64_string: Base64 text, optionally preceded by a header ending in
            a comma (everything up to and including the first comma is
            ignored).

    Returns:
        ``len(payload) * 3 // 4`` minus the number of ``=`` characters, where
        ``payload`` is the base64 text with newlines, carriage returns and
        spaces removed.
    """
    # Keep only what follows the first comma, if there is one.
    _header, comma, remainder = base64_string.partition(',')
    payload = remainder if comma else base64_string

    # Drop line breaks and spaces in a single pass.
    payload = payload.translate(str.maketrans('', '', '\n\r '))

    pad_count = payload.count('=')
    return len(payload) * 3 // 4 - pad_count