Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -36,44 +36,48 @@ def encode_image_to_base64(image):
|
|
36 |
image.save(buffered, format="PNG")
|
37 |
return base64.b64encode(buffered.getvalue()).decode('utf-8')
|
38 |
|
|
|
39 |
def analyze_image(image):
|
40 |
client = OpenAI(api_key=OPENAI_API_KEY)
|
41 |
base64_image = encode_image_to_base64(image)
|
42 |
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
{
|
45 |
-
"
|
46 |
-
"
|
47 |
-
{
|
48 |
-
|
49 |
-
"text": """Your task is to determine if the image is surprising or not surprising.
|
50 |
-
if the image is surprising, determine which element, figure or object in the image is making the image surprising and write it only in one sentence with no more then 6 words, otherwise, write 'NA'.
|
51 |
-
Also rate how surprising the image is on a scale of 1-5, where 1 is not surprising at all and 5 is highly surprising.
|
52 |
-
Provide the response as a JSON with the following structure:
|
53 |
-
{
|
54 |
-
"label": "[surprising OR not surprising]",
|
55 |
-
"element": "[element]",
|
56 |
-
"rating": [1-5]
|
57 |
-
}"""
|
58 |
-
},
|
59 |
-
{
|
60 |
-
"type": "image_url",
|
61 |
-
"image_url": {
|
62 |
-
"url": f"data:image/jpeg;base64,{base64_image}"
|
63 |
-
}
|
64 |
-
}
|
65 |
-
]
|
66 |
}
|
67 |
]
|
68 |
|
|
|
|
|
|
|
69 |
response = client.chat.completions.create(
|
70 |
model="gpt-4o-mini",
|
71 |
-
messages=
|
|
|
|
|
|
|
|
|
|
|
72 |
max_tokens=100,
|
73 |
temperature=0.1,
|
74 |
-
response_format={
|
75 |
-
"type": "json_object"
|
76 |
-
}
|
77 |
)
|
78 |
|
79 |
return response.choices[0].message.content
|
@@ -102,7 +106,7 @@ def process_image_detection(image, target_label, surprise_rating):
|
|
102 |
original_size = image.size
|
103 |
|
104 |
# Calculate relative font size based on image dimensions
|
105 |
-
base_fontsize = min(original_size) / 40 # Adjust this divisor
|
106 |
|
107 |
owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
|
108 |
owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
|
@@ -125,7 +129,6 @@ def process_image_detection(image, target_label, surprise_rating):
|
|
125 |
|
126 |
ax = plt.Axes(fig, [0., 0., 1., 1.])
|
127 |
fig.add_axes(ax)
|
128 |
-
|
129 |
plt.imshow(image)
|
130 |
|
131 |
scores = results["scores"]
|
@@ -154,18 +157,18 @@ def process_image_detection(image, target_label, surprise_rating):
|
|
154 |
mask = masks[0].numpy() if isinstance(masks[0], torch.Tensor) else masks[0]
|
155 |
show_mask(mask, ax=ax)
|
156 |
|
157 |
-
# Draw rectangle
|
158 |
rect = patches.Rectangle(
|
159 |
(box[0], box[1]),
|
160 |
box[2] - box[0],
|
161 |
box[3] - box[1],
|
162 |
-
linewidth=max(2, min(original_size) / 500),
|
163 |
edgecolor='red',
|
164 |
facecolor='none'
|
165 |
)
|
166 |
ax.add_patch(rect)
|
167 |
|
168 |
-
#
|
169 |
plt.text(
|
170 |
box[0], box[1] - base_fontsize,
|
171 |
f'{max_score:.2f}',
|
@@ -175,7 +178,7 @@ def process_image_detection(image, target_label, surprise_rating):
|
|
175 |
bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2)
|
176 |
)
|
177 |
|
178 |
-
#
|
179 |
plt.text(
|
180 |
box[2] + base_fontsize / 2, box[1],
|
181 |
f'Unexpected (Rating: {surprise_rating}/5)\n{target_label}',
|
@@ -188,18 +191,20 @@ def process_image_detection(image, target_label, surprise_rating):
|
|
188 |
|
189 |
plt.axis('off')
|
190 |
|
191 |
-
# Save
|
192 |
buf = io.BytesIO()
|
193 |
-
plt.savefig(
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
|
|
|
|
199 |
buf.seek(0)
|
200 |
plt.close()
|
201 |
|
202 |
-
#
|
203 |
output_image = Image.open(buf)
|
204 |
output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
|
205 |
|
@@ -220,29 +225,34 @@ def process_and_analyze(image):
|
|
220 |
try:
|
221 |
# Handle different input types
|
222 |
if isinstance(image, tuple):
|
223 |
-
image = image[0]
|
224 |
if isinstance(image, np.ndarray):
|
225 |
image = Image.fromarray(image)
|
226 |
if not isinstance(image, Image.Image):
|
227 |
raise ValueError("Invalid image format")
|
228 |
|
229 |
-
# Analyze image
|
230 |
gpt_response = analyze_image(image)
|
231 |
response_data = json.loads(gpt_response)
|
232 |
|
|
|
233 |
if response_data["label"].lower() == "surprising" and response_data["element"].lower() != "na":
|
234 |
result_buf = process_image_detection(image, response_data["element"], response_data["rating"])
|
235 |
result_image = Image.open(result_buf)
|
236 |
-
analysis_text =
|
|
|
|
|
|
|
|
|
237 |
return result_image, analysis_text
|
238 |
else:
|
|
|
239 |
return image, "Not Surprising"
|
240 |
|
241 |
except Exception as e:
|
242 |
return None, f"Error processing image: {str(e)}"
|
243 |
|
244 |
|
245 |
-
# Create Gradio interface
|
246 |
def create_interface():
|
247 |
with gr.Blocks() as demo:
|
248 |
gr.Markdown("# Image Surprise Analysis")
|
@@ -267,4 +277,4 @@ def create_interface():
|
|
267 |
|
268 |
if __name__ == "__main__":
|
269 |
demo = create_interface()
|
270 |
-
demo.launch()
|
|
|
36 |
image.save(buffered, format="PNG")
|
37 |
return base64.b64encode(buffered.getvalue()).decode('utf-8')
|
38 |
|
39 |
+
|
40 |
def analyze_image(image):
|
41 |
client = OpenAI(api_key=OPENAI_API_KEY)
|
42 |
base64_image = encode_image_to_base64(image)
|
43 |
|
44 |
+
# Build the list-of-dicts prompt:
|
45 |
+
prompt_dict = [
|
46 |
+
{
|
47 |
+
"type": "text",
|
48 |
+
"text": """Your task is to determine if the image is surprising or not surprising.
|
49 |
+
If the image is surprising, determine which element, figure, or object in the image is making the image surprising and write it only in one sentence with no more than 6 words.
|
50 |
+
Otherwise, write 'NA'.
|
51 |
+
Also, rate how surprising the image is on a scale of 1-5, where 1 is not surprising at all and 5 is highly surprising.
|
52 |
+
Provide the response as a JSON with the following structure:
|
53 |
+
{
|
54 |
+
"label": "[surprising OR not surprising]",
|
55 |
+
"element": "[element]",
|
56 |
+
"rating": [1-5]
|
57 |
+
}"""
|
58 |
+
},
|
59 |
{
|
60 |
+
"type": "image_url",
|
61 |
+
"image_url": {
|
62 |
+
"url": f"data:image/jpeg;base64,{base64_image}"
|
63 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
}
|
65 |
]
|
66 |
|
67 |
+
# JSON-encode the list so "content" is a string
|
68 |
+
json_prompt = json.dumps(prompt_dict)
|
69 |
+
|
70 |
response = client.chat.completions.create(
|
71 |
model="gpt-4o-mini",
|
72 |
+
messages=[
|
73 |
+
{
|
74 |
+
"role": "user",
|
75 |
+
"content": json_prompt, # must be a string
|
76 |
+
}
|
77 |
+
],
|
78 |
max_tokens=100,
|
79 |
temperature=0.1,
|
80 |
+
response_format={"type": "json_object"}
|
|
|
|
|
81 |
)
|
82 |
|
83 |
return response.choices[0].message.content
|
|
|
106 |
original_size = image.size
|
107 |
|
108 |
# Calculate relative font size based on image dimensions
|
109 |
+
base_fontsize = min(original_size) / 40 # Adjust this divisor as needed
|
110 |
|
111 |
owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
|
112 |
owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
|
|
|
129 |
|
130 |
ax = plt.Axes(fig, [0., 0., 1., 1.])
|
131 |
fig.add_axes(ax)
|
|
|
132 |
plt.imshow(image)
|
133 |
|
134 |
scores = results["scores"]
|
|
|
157 |
mask = masks[0].numpy() if isinstance(masks[0], torch.Tensor) else masks[0]
|
158 |
show_mask(mask, ax=ax)
|
159 |
|
160 |
+
# Draw rectangle around the detected area
|
161 |
rect = patches.Rectangle(
|
162 |
(box[0], box[1]),
|
163 |
box[2] - box[0],
|
164 |
box[3] - box[1],
|
165 |
+
linewidth=max(2, min(original_size) / 500),
|
166 |
edgecolor='red',
|
167 |
facecolor='none'
|
168 |
)
|
169 |
ax.add_patch(rect)
|
170 |
|
171 |
+
# Confidence score
|
172 |
plt.text(
|
173 |
box[0], box[1] - base_fontsize,
|
174 |
f'{max_score:.2f}',
|
|
|
178 |
bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2)
|
179 |
)
|
180 |
|
181 |
+
# Label + rating
|
182 |
plt.text(
|
183 |
box[2] + base_fontsize / 2, box[1],
|
184 |
f'Unexpected (Rating: {surprise_rating}/5)\n{target_label}',
|
|
|
191 |
|
192 |
plt.axis('off')
|
193 |
|
194 |
+
# Save figure to buffer
|
195 |
buf = io.BytesIO()
|
196 |
+
plt.savefig(
|
197 |
+
buf,
|
198 |
+
format='png',
|
199 |
+
dpi=dpi,
|
200 |
+
bbox_inches='tight',
|
201 |
+
pad_inches=0,
|
202 |
+
metadata={'dpi': original_dpi}
|
203 |
+
)
|
204 |
buf.seek(0)
|
205 |
plt.close()
|
206 |
|
207 |
+
# Convert buffer back to PIL
|
208 |
output_image = Image.open(buf)
|
209 |
output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
|
210 |
|
|
|
225 |
try:
|
226 |
# Handle different input types
|
227 |
if isinstance(image, tuple):
|
228 |
+
image = image[0]
|
229 |
if isinstance(image, np.ndarray):
|
230 |
image = Image.fromarray(image)
|
231 |
if not isinstance(image, Image.Image):
|
232 |
raise ValueError("Invalid image format")
|
233 |
|
234 |
+
# Analyze image with GPT
|
235 |
gpt_response = analyze_image(image)
|
236 |
response_data = json.loads(gpt_response)
|
237 |
|
238 |
+
# If surprising, try to detect the element
|
239 |
if response_data["label"].lower() == "surprising" and response_data["element"].lower() != "na":
|
240 |
result_buf = process_image_detection(image, response_data["element"], response_data["rating"])
|
241 |
result_image = Image.open(result_buf)
|
242 |
+
analysis_text = (
|
243 |
+
f"Label: {response_data['label']}\n"
|
244 |
+
f"Element: {response_data['element']}\n"
|
245 |
+
f"Rating: {response_data['rating']}/5"
|
246 |
+
)
|
247 |
return result_image, analysis_text
|
248 |
else:
|
249 |
+
# If not surprising or element=NA
|
250 |
return image, "Not Surprising"
|
251 |
|
252 |
except Exception as e:
|
253 |
return None, f"Error processing image: {str(e)}"
|
254 |
|
255 |
|
|
|
256 |
def create_interface():
|
257 |
with gr.Blocks() as demo:
|
258 |
gr.Markdown("# Image Surprise Analysis")
|
|
|
277 |
|
278 |
if __name__ == "__main__":
|
279 |
demo = create_interface()
|
280 |
+
demo.launch()
|