shivalikasingh committed
Commit 43811b3 · verified · 1 Parent(s): 2fa14b4

Create aya_vision_utils.py

Files changed (1)
  1. aya_vision_utils.py +109 -0
aya_vision_utils.py ADDED
@@ -0,0 +1,109 @@
+ from constants import IMAGE_PER_CONVERSATION_LIMIT, DEFAULT_SYSTEM_PREAMBLE_TOKEN_COUNT, VISION_COHERE_MODEL_NAME, VISION_MODEL_TOKEN_LIMIT
+ from prompt_examples import AYA_VISION_PROMPT_EXAMPLES
+ import base64
+ from io import BytesIO
+ from PIL import Image
+ import logging
+ import cohere
+ import os
+ import traceback
+ import random
+ import gradio as gr
+
+ # from dotenv import load_dotenv
+ # load_dotenv()
+
+ MULTIMODAL_API_KEY = os.getenv("AYA_VISION_API_KEY")
+
+ logger = logging.getLogger(__name__)
+
+ aya_vision_client = cohere.ClientV2(
+     api_key=MULTIMODAL_API_KEY,
+     client_name="c4ai-aya-vision-hf-space",
+ )
+
+ def cohere_vision_chat(chat_history, model=VISION_COHERE_MODEL_NAME):
+     response = aya_vision_client.chat(
+         messages=chat_history,
+         model=model,
+     )
+     return response.message.content[0].text
+
+
+ def get_aya_vision_prompt_example(language):
+     example = AYA_VISION_PROMPT_EXAMPLES[language]
+     print("example:", example)
+     print("example prompt:", example[0])
+     print("example image:", example[1])
+     return example[0], example[1]
+
+ def get_base64_from_local_file(file_path):
+     try:
+         print("loading image")
+         with open(file_path, "rb") as image_file:
+             base64_image = base64.b64encode(image_file.read()).decode('utf-8')
+         print("converted image")
+         return base64_image
+     except Exception as e:
+         logger.debug(f"Error converting local image to base64 string: {e}")
+         return None
+
+
+ def get_aya_vision_response(incoming_message, image_filepath, max_size_mb=5):
+     print("incoming message:", incoming_message)
+     print("image_filepath:", image_filepath)
+     max_size_bytes = max_size_mb * 1024 * 1024
+
+     if image_filepath.endswith(".jpg") or image_filepath.endswith(".jpeg"):
+         image_type = "image/jpeg"
+     elif image_filepath.endswith(".png"):
+         image_type = "image/png"
+     elif image_filepath.endswith(".webp"):
+         image_type = "image/webp"
+     elif image_filepath.endswith(".gif"):
+         image_type = "image/gif"
+     else:
+         # fallback so image_type is never left unbound for other extensions
+         raise gr.Error("Please upload a .jpg, .jpeg, .png, .webp or .gif image")
+
+     chat_history = []
+     print("converting image to base 64")
+     base64_image = get_base64_from_local_file(image_filepath)
+     image = f"data:{image_type};base64,{base64_image}"
+     print("Image base64:", image[:30])
+
+     # to prevent Cohere API from throwing error for empty message
+     if incoming_message == "" or incoming_message is None:
+         incoming_message = "."
+
+     chat_history.append(
+         {
+             "role": "user",
+             "content": [
+                 {"type": "text", "text": incoming_message},
+                 {"type": "image_url", "image_url": {"url": image}},
+             ],
+         }
+     )
+
+     image_size_bytes = get_base64_image_size(image)
+     if image_size_bytes >= max_size_bytes:
+         # must be raised, otherwise the oversized image is still sent to the API
+         raise gr.Error("Please upload an image with size under 5MB")
+
+     # response = cohere_vision_chat_stream(chat_history, model=VISION_COHERE_MODEL_NAME)
+     # return response
+     res = aya_vision_client.chat_stream(messages=chat_history, model=VISION_COHERE_MODEL_NAME)
+     output = ""
+
+     for event in res:
+         if event:
+             if event.type == "content-delta":
+                 output += event.delta.message.content.text
+                 yield output
+
+ def get_base64_image_size(base64_string):
+     if ',' in base64_string:
+         base64_data = base64_string.split(',', 1)[1]
+     else:
+         base64_data = base64_string
+
+     base64_data = base64_data.replace('\n', '').replace('\r', '').replace(' ', '')
+     padding = base64_data.count('=')
+     size_bytes = (len(base64_data) * 3) // 4 - padding
+     return size_bytes
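
For reference, a minimal caller sketch for the streaming helper above (not part of the committed file). It assumes this module is importable as aya_vision_utils, that the constants and prompt_examples modules it imports are on the path, that AYA_VISION_API_KEY is set, and that "sample.jpg" is a placeholder path.

# hypothetical standalone caller; get_aya_vision_response is a generator,
# and each yield is the accumulated response text so far
from aya_vision_utils import get_aya_vision_response

if __name__ == "__main__":
    final_text = ""
    for partial in get_aya_vision_response("Describe this image.", "sample.jpg"):
        final_text = partial
    print(final_text)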