reab5555 committed on
Commit
3da6b4f
·
verified ·
1 Parent(s): a27e6f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -45
app.py CHANGED
@@ -36,44 +36,48 @@ def encode_image_to_base64(image):
36
  image.save(buffered, format="PNG")
37
  return base64.b64encode(buffered.getvalue()).decode('utf-8')
38
 
 
39
  def analyze_image(image):
40
  client = OpenAI(api_key=OPENAI_API_KEY)
41
  base64_image = encode_image_to_base64(image)
42
 
43
- messages = [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  {
45
- "role": "user",
46
- "content": [
47
- {
48
- "type": "text",
49
- "text": """Your task is to determine if the image is surprising or not surprising.
50
- if the image is surprising, determine which element, figure or object in the image is making the image surprising and write it only in one sentence with no more then 6 words, otherwise, write 'NA'.
51
- Also rate how surprising the image is on a scale of 1-5, where 1 is not surprising at all and 5 is highly surprising.
52
- Provide the response as a JSON with the following structure:
53
- {
54
- "label": "[surprising OR not surprising]",
55
- "element": "[element]",
56
- "rating": [1-5]
57
- }"""
58
- },
59
- {
60
- "type": "image_url",
61
- "image_url": {
62
- "url": f"data:image/jpeg;base64,{base64_image}"
63
- }
64
- }
65
- ]
66
  }
67
  ]
68
 
 
 
 
69
  response = client.chat.completions.create(
70
  model="gpt-4o-mini",
71
- messages=messages,
 
 
 
 
 
72
  max_tokens=100,
73
  temperature=0.1,
74
- response_format={
75
- "type": "json_object"
76
- }
77
  )
78
 
79
  return response.choices[0].message.content
@@ -102,7 +106,7 @@ def process_image_detection(image, target_label, surprise_rating):
102
  original_size = image.size
103
 
104
  # Calculate relative font size based on image dimensions
105
- base_fontsize = min(original_size) / 40 # Adjust this divisor to change overall font size
106
 
107
  owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
108
  owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
@@ -125,7 +129,6 @@ def process_image_detection(image, target_label, surprise_rating):
125
 
126
  ax = plt.Axes(fig, [0., 0., 1., 1.])
127
  fig.add_axes(ax)
128
-
129
  plt.imshow(image)
130
 
131
  scores = results["scores"]
@@ -154,18 +157,18 @@ def process_image_detection(image, target_label, surprise_rating):
154
  mask = masks[0].numpy() if isinstance(masks[0], torch.Tensor) else masks[0]
155
  show_mask(mask, ax=ax)
156
 
157
- # Draw rectangle with increased line width
158
  rect = patches.Rectangle(
159
  (box[0], box[1]),
160
  box[2] - box[0],
161
  box[3] - box[1],
162
- linewidth=max(2, min(original_size) / 500), # Scale line width with image size
163
  edgecolor='red',
164
  facecolor='none'
165
  )
166
  ax.add_patch(rect)
167
 
168
- # Add confidence score with improved visibility
169
  plt.text(
170
  box[0], box[1] - base_fontsize,
171
  f'{max_score:.2f}',
@@ -175,7 +178,7 @@ def process_image_detection(image, target_label, surprise_rating):
175
  bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2)
176
  )
177
 
178
- # Add label and rating with improved visibility
179
  plt.text(
180
  box[2] + base_fontsize / 2, box[1],
181
  f'Unexpected (Rating: {surprise_rating}/5)\n{target_label}',
@@ -188,18 +191,20 @@ def process_image_detection(image, target_label, surprise_rating):
188
 
189
  plt.axis('off')
190
 
191
- # Save with high DPI
192
  buf = io.BytesIO()
193
- plt.savefig(buf,
194
- format='png',
195
- dpi=dpi,
196
- bbox_inches='tight',
197
- pad_inches=0,
198
- metadata={'dpi': original_dpi})
 
 
199
  buf.seek(0)
200
  plt.close()
201
 
202
- # Process final image
203
  output_image = Image.open(buf)
204
  output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
205
 
@@ -220,29 +225,34 @@ def process_and_analyze(image):
220
  try:
221
  # Handle different input types
222
  if isinstance(image, tuple):
223
- image = image[0] # Take the first element if it's a tuple
224
  if isinstance(image, np.ndarray):
225
  image = Image.fromarray(image)
226
  if not isinstance(image, Image.Image):
227
  raise ValueError("Invalid image format")
228
 
229
- # Analyze image
230
  gpt_response = analyze_image(image)
231
  response_data = json.loads(gpt_response)
232
 
 
233
  if response_data["label"].lower() == "surprising" and response_data["element"].lower() != "na":
234
  result_buf = process_image_detection(image, response_data["element"], response_data["rating"])
235
  result_image = Image.open(result_buf)
236
- analysis_text = f"Label: {response_data['label']}\nElement: {response_data['element']}\nRating: {response_data['rating']}/5"
 
 
 
 
237
  return result_image, analysis_text
238
  else:
 
239
  return image, "Not Surprising"
240
 
241
  except Exception as e:
242
  return None, f"Error processing image: {str(e)}"
243
 
244
 
245
- # Create Gradio interface
246
  def create_interface():
247
  with gr.Blocks() as demo:
248
  gr.Markdown("# Image Surprise Analysis")
@@ -267,4 +277,4 @@ def create_interface():
267
 
268
  if __name__ == "__main__":
269
  demo = create_interface()
270
- demo.launch()
 
36
  image.save(buffered, format="PNG")
37
  return base64.b64encode(buffered.getvalue()).decode('utf-8')
38
 
39
+
40
def analyze_image(image):
    """Ask GPT-4o-mini whether *image* is surprising.

    Parameters
    ----------
    image : PIL.Image.Image
        Image to analyze; it is base64-encoded and sent as an image part.

    Returns
    -------
    str
        The model's raw JSON string with keys "label"
        ("surprising"/"not surprising"), "element" (short description or
        "NA"), and "rating" (1-5).
    """
    client = OpenAI(api_key=OPENAI_API_KEY)
    base64_image = encode_image_to_base64(image)

    # BUG FIX: multimodal content must be a LIST of typed parts
    # ("text" / "image_url"), not a JSON-encoded string. Serializing the
    # list with json.dumps() made the API treat the whole payload —
    # including the base64 blob — as plain text, so the model never
    # actually received the image.
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": """Your task is to determine if the image is surprising or not surprising.
If the image is surprising, determine which element, figure, or object in the image is making the image surprising and write it only in one sentence with no more than 6 words.
Otherwise, write 'NA'.
Also, rate how surprising the image is on a scale of 1-5, where 1 is not surprising at all and 5 is highly surprising.
Provide the response as a JSON with the following structure:
{
    "label": "[surprising OR not surprising]",
    "element": "[element]",
    "rating": [1-5]
}"""
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}"
                    }
                }
            ]
        }
    ]

    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages,
        max_tokens=100,
        temperature=0.1,
        # Constrain the reply to a well-formed JSON object so the caller's
        # json.loads() cannot fail on extra prose.
        response_format={"type": "json_object"},
    )

    return response.choices[0].message.content
 
106
  original_size = image.size
107
 
108
  # Calculate relative font size based on image dimensions
109
+ base_fontsize = min(original_size) / 40 # Adjust this divisor as needed
110
 
111
  owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
112
  owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
 
129
 
130
  ax = plt.Axes(fig, [0., 0., 1., 1.])
131
  fig.add_axes(ax)
 
132
  plt.imshow(image)
133
 
134
  scores = results["scores"]
 
157
  mask = masks[0].numpy() if isinstance(masks[0], torch.Tensor) else masks[0]
158
  show_mask(mask, ax=ax)
159
 
160
+ # Draw rectangle around the detected area
161
  rect = patches.Rectangle(
162
  (box[0], box[1]),
163
  box[2] - box[0],
164
  box[3] - box[1],
165
+ linewidth=max(2, min(original_size) / 500),
166
  edgecolor='red',
167
  facecolor='none'
168
  )
169
  ax.add_patch(rect)
170
 
171
+ # Confidence score
172
  plt.text(
173
  box[0], box[1] - base_fontsize,
174
  f'{max_score:.2f}',
 
178
  bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2)
179
  )
180
 
181
+ # Label + rating
182
  plt.text(
183
  box[2] + base_fontsize / 2, box[1],
184
  f'Unexpected (Rating: {surprise_rating}/5)\n{target_label}',
 
191
 
192
  plt.axis('off')
193
 
194
+ # Save figure to buffer
195
  buf = io.BytesIO()
196
+ plt.savefig(
197
+ buf,
198
+ format='png',
199
+ dpi=dpi,
200
+ bbox_inches='tight',
201
+ pad_inches=0,
202
+ metadata={'dpi': original_dpi}
203
+ )
204
  buf.seek(0)
205
  plt.close()
206
 
207
+ # Convert buffer back to PIL
208
  output_image = Image.open(buf)
209
  output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
210
 
 
225
  try:
226
  # Handle different input types
227
  if isinstance(image, tuple):
228
+ image = image[0]
229
  if isinstance(image, np.ndarray):
230
  image = Image.fromarray(image)
231
  if not isinstance(image, Image.Image):
232
  raise ValueError("Invalid image format")
233
 
234
+ # Analyze image with GPT
235
  gpt_response = analyze_image(image)
236
  response_data = json.loads(gpt_response)
237
 
238
+ # If surprising, try to detect the element
239
  if response_data["label"].lower() == "surprising" and response_data["element"].lower() != "na":
240
  result_buf = process_image_detection(image, response_data["element"], response_data["rating"])
241
  result_image = Image.open(result_buf)
242
+ analysis_text = (
243
+ f"Label: {response_data['label']}\n"
244
+ f"Element: {response_data['element']}\n"
245
+ f"Rating: {response_data['rating']}/5"
246
+ )
247
  return result_image, analysis_text
248
  else:
249
+ # If not surprising or element=NA
250
  return image, "Not Surprising"
251
 
252
  except Exception as e:
253
  return None, f"Error processing image: {str(e)}"
254
 
255
 
 
256
  def create_interface():
257
  with gr.Blocks() as demo:
258
  gr.Markdown("# Image Surprise Analysis")
 
277
 
278
  if __name__ == "__main__":
279
  demo = create_interface()
280
+ demo.launch()