reab5555 committed on
Commit
f8ce6ee
·
verified ·
1 Parent(s): f57fde9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -61
app.py CHANGED
@@ -19,48 +19,42 @@ load_dotenv()
19
  OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
20
 
21
 
22
def resize_and_compress(image, max_width=800, max_height=800, quality=50):
    """Resize (if larger than max_width/max_height) and JPEG-compress *image*.

    Parameters
    ----------
    image : PIL.Image.Image
        Source image; must be a PIL Image.
    max_width, max_height : int
        Maximum output dimensions; aspect ratio is preserved.
    quality : int
        JPEG quality (lower = smaller output), used to keep the
        Base64 payload under roughly 1 MB.

    Returns
    -------
    str
        Base64-encoded JPEG bytes.

    Raises
    ------
    ValueError
        If *image* is not a PIL Image.
    """
    if not isinstance(image, Image.Image):
        raise ValueError("Input must be a PIL Image")

    width, height = image.size
    if width > max_width or height > max_height:
        aspect_ratio = width / height
        if aspect_ratio > 1:
            # Landscape: clamp width, derive height.
            new_width = max_width
            new_height = int(new_width / aspect_ratio)
        else:
            # Portrait/square: clamp height, derive width.
            new_height = max_height
            new_width = int(new_height * aspect_ratio)
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Bug fix: JPEG cannot encode alpha/palette modes (e.g. RGBA, P, LA);
    # without this conversion .save() raises OSError for PNG-sourced images.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    buffered = io.BytesIO()
    # Save as JPEG with reduced quality to shrink the Base64 payload.
    image.save(buffered, format="JPEG", quality=quality)
    buffered.seek(0)
    return base64.b64encode(buffered.getvalue()).decode('utf-8')
43
 
 
44
  def analyze_image(image):
45
  client = OpenAI(api_key=OPENAI_API_KEY)
 
46
 
47
- # Step 1: Resize + compress to keep the Base64 string under 1 MB
48
- base64_image = resize_and_compress(image, max_width=800, max_height=800, quality=50)
49
-
50
- # Build the list-of-dicts prompt
51
- prompt_dict = [
52
  {
53
  "type": "text",
54
- "text": """Your task is to determine if the image is surprising or not.
55
- If the image is surprising, which element is surprising (max 6 words).
56
- Otherwise, 'NA'. Also rate how surprising (1-5).
57
- Return JSON like:
58
  {
59
- "label": "[surprising or not surprising]",
60
- "element": "[element]",
61
- "rating": [1-5]
62
- }
63
- """
64
  },
65
  {
66
  "type": "image_url",
@@ -70,27 +64,29 @@ def analyze_image(image):
70
  }
71
  ]
72
 
73
- # JSON-encode to ensure content is a string
74
- json_prompt = json.dumps(prompt_dict)
 
 
 
 
 
 
 
75
 
76
- # Send request
77
  response = client.chat.completions.create(
78
- model="gpt-4o-mini",
79
- messages=[
80
- {
81
- "role": "user",
82
- "content": json_prompt
83
- }
84
- ],
85
  max_tokens=100,
86
  temperature=0.1,
87
- response_format={"type": "json_object"}
 
 
88
  )
89
 
90
  return response.choices[0].message.content
91
 
92
 
93
-
94
  def show_mask(mask, ax, random_color=False):
95
  if random_color:
96
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
@@ -114,7 +110,7 @@ def process_image_detection(image, target_label, surprise_rating):
114
  original_size = image.size
115
 
116
  # Calculate relative font size based on image dimensions
117
- base_fontsize = min(original_size) / 40 # Adjust this divisor as needed
118
 
119
  owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
120
  owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
@@ -137,6 +133,7 @@ def process_image_detection(image, target_label, surprise_rating):
137
 
138
  ax = plt.Axes(fig, [0., 0., 1., 1.])
139
  fig.add_axes(ax)
 
140
  plt.imshow(image)
141
 
142
  scores = results["scores"]
@@ -165,7 +162,7 @@ def process_image_detection(image, target_label, surprise_rating):
165
  mask = masks[0].numpy() if isinstance(masks[0], torch.Tensor) else masks[0]
166
  show_mask(mask, ax=ax)
167
 
168
- # Draw rectangle around the detected area
169
  rect = patches.Rectangle(
170
  (box[0], box[1]),
171
  box[2] - box[0],
@@ -176,7 +173,7 @@ def process_image_detection(image, target_label, surprise_rating):
176
  )
177
  ax.add_patch(rect)
178
 
179
- # Confidence score
180
  plt.text(
181
  box[0], box[1] - base_fontsize,
182
  f'{max_score:.2f}',
@@ -186,7 +183,7 @@ def process_image_detection(image, target_label, surprise_rating):
186
  bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2)
187
  )
188
 
189
- # Label + rating
190
  plt.text(
191
  box[2] + base_fontsize / 2, box[1],
192
  f'Unexpected (Rating: {surprise_rating}/5)\n{target_label}',
@@ -199,20 +196,17 @@ def process_image_detection(image, target_label, surprise_rating):
199
 
200
  plt.axis('off')
201
 
202
- # Save figure to buffer
203
  buf = io.BytesIO()
204
- plt.savefig(
205
- buf,
206
- format='png',
207
- dpi=dpi,
208
- bbox_inches='tight',
209
- pad_inches=0,
210
- metadata={'dpi': original_dpi}
211
- )
212
  buf.seek(0)
213
  plt.close()
214
 
215
- # Convert buffer back to PIL
216
  output_image = Image.open(buf)
217
  output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
218
 
@@ -233,17 +227,16 @@ def process_and_analyze(image):
233
  try:
234
  # Handle different input types
235
  if isinstance(image, tuple):
236
- image = image[0]
237
  if isinstance(image, np.ndarray):
238
  image = Image.fromarray(image)
239
  if not isinstance(image, Image.Image):
240
  raise ValueError("Invalid image format")
241
 
242
- # Analyze image with GPT
243
  gpt_response = analyze_image(image)
244
  response_data = json.loads(gpt_response)
245
 
246
- # If surprising, try to detect the element
247
  if response_data["label"].lower() == "surprising" and response_data["element"].lower() != "na":
248
  result_buf = process_image_detection(image, response_data["element"], response_data["rating"])
249
  result_image = Image.open(result_buf)
@@ -254,7 +247,6 @@ def process_and_analyze(image):
254
  )
255
  return result_image, analysis_text
256
  else:
257
- # If not surprising or element=NA
258
  return image, "Not Surprising"
259
 
260
  except Exception as e:
 
19
  OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
20
 
21
 
22
def encode_image_to_base64(image):
    """Normalize *image* to a PIL Image and return it as a Base64 PNG string.

    Accepts a PIL Image, a numpy array, or a tuple whose first element is
    an image (the tuple form is what Gradio sometimes hands back).

    Raises
    ------
    ValueError
        If the input cannot be interpreted as an image.
    """
    # Gradio occasionally wraps the image in a tuple; unwrap the payload.
    if isinstance(image, tuple):
        image = image[0]

    # A numpy array is converted into a PIL Image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Anything that is still not a PIL Image at this point is rejected.
    if not isinstance(image, Image.Image):
        raise ValueError("Input must be a PIL Image, numpy array, or tuple containing an image")

    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode('utf-8')
38
 
39
+
40
  def analyze_image(image):
41
  client = OpenAI(api_key=OPENAI_API_KEY)
42
+ base64_image = encode_image_to_base64(image)
43
 
44
+ # --- MINIMAL FIX START ---
45
+ # We build a Python list of dicts, then JSON-encode it:
46
+ prompt_list = [
 
 
47
  {
48
  "type": "text",
49
+ "text": """Your task is to determine if the image is surprising or not surprising.
50
+ if the image is surprising, determine which element, figure or object in the image is making the image surprising and write it only in one sentence with no more then 6 words, otherwise, write 'NA'.
51
+ Also rate how surprising the image is on a scale of 1-5, where 1 is not surprising at all and 5 is highly surprising.
52
+ Provide the response as a JSON with the following structure:
53
  {
54
+ "label": "[surprising OR not surprising]",
55
+ "element": "[element]",
56
+ "rating": [1-5]
57
+ }"""
 
58
  },
59
  {
60
  "type": "image_url",
 
64
  }
65
  ]
66
 
67
+ prompt_json = json.dumps(prompt_list)
68
+
69
+ messages = [
70
+ {
71
+ "role": "user",
72
+ "content": prompt_json # content must be a single string
73
+ }
74
+ ]
75
+ # --- MINIMAL FIX END ---
76
 
 
77
  response = client.chat.completions.create(
78
+ model="gpt-4o-mini", # or whichever model you have access to
79
+ messages=messages,
 
 
 
 
 
80
  max_tokens=100,
81
  temperature=0.1,
82
+ response_format={
83
+ "type": "json_object"
84
+ }
85
  )
86
 
87
  return response.choices[0].message.content
88
 
89
 
 
90
  def show_mask(mask, ax, random_color=False):
91
  if random_color:
92
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
 
110
  original_size = image.size
111
 
112
  # Calculate relative font size based on image dimensions
113
+ base_fontsize = min(original_size) / 40 # Adjust this divisor to change overall font size
114
 
115
  owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
116
  owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
 
133
 
134
  ax = plt.Axes(fig, [0., 0., 1., 1.])
135
  fig.add_axes(ax)
136
+
137
  plt.imshow(image)
138
 
139
  scores = results["scores"]
 
162
  mask = masks[0].numpy() if isinstance(masks[0], torch.Tensor) else masks[0]
163
  show_mask(mask, ax=ax)
164
 
165
+ # Draw rectangle with increased line width
166
  rect = patches.Rectangle(
167
  (box[0], box[1]),
168
  box[2] - box[0],
 
173
  )
174
  ax.add_patch(rect)
175
 
176
+ # Add confidence score with improved visibility
177
  plt.text(
178
  box[0], box[1] - base_fontsize,
179
  f'{max_score:.2f}',
 
183
  bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2)
184
  )
185
 
186
+ # Add label and rating with improved visibility
187
  plt.text(
188
  box[2] + base_fontsize / 2, box[1],
189
  f'Unexpected (Rating: {surprise_rating}/5)\n{target_label}',
 
196
 
197
  plt.axis('off')
198
 
 
199
  buf = io.BytesIO()
200
+ plt.savefig(buf,
201
+ format='png',
202
+ dpi=dpi,
203
+ bbox_inches='tight',
204
+ pad_inches=0,
205
+ metadata={'dpi': original_dpi})
 
 
206
  buf.seek(0)
207
  plt.close()
208
 
209
+ # Process final image
210
  output_image = Image.open(buf)
211
  output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
212
 
 
227
  try:
228
  # Handle different input types
229
  if isinstance(image, tuple):
230
+ image = image[0] # Take the first element if it's a tuple
231
  if isinstance(image, np.ndarray):
232
  image = Image.fromarray(image)
233
  if not isinstance(image, Image.Image):
234
  raise ValueError("Invalid image format")
235
 
236
+ # Analyze image
237
  gpt_response = analyze_image(image)
238
  response_data = json.loads(gpt_response)
239
 
 
240
  if response_data["label"].lower() == "surprising" and response_data["element"].lower() != "na":
241
  result_buf = process_image_detection(image, response_data["element"], response_data["rating"])
242
  result_image = Image.open(result_buf)
 
247
  )
248
  return result_image, analysis_text
249
  else:
 
250
  return image, "Not Surprising"
251
 
252
  except Exception as e: