reab5555 commited on
Commit
ec4100b
·
verified ·
1 Parent(s): b7b543a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -13
app.py CHANGED
@@ -87,9 +87,12 @@ def process_image_detection(image, target_label, surprise_rating):
87
  device = "cuda" if torch.cuda.is_available() else "cpu"
88
 
89
  # Get original image DPI and size
90
- original_dpi = image.info.get('dpi', (72, 72)) # Default to 72 DPI if not specified
91
  original_size = image.size
92
 
 
 
 
93
  owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-large-patch14")
94
  owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-large-patch14").to(device)
95
 
@@ -105,12 +108,10 @@ def process_image_detection(image, target_label, surprise_rating):
105
  target_sizes = torch.tensor([image.size[::-1]]).to(device)
106
  results = owlv2_processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
107
 
108
- # Create figure with the exact pixel size of the original image
109
- dpi = 100 # Base DPI for calculation
110
  figsize = (original_size[0] / dpi, original_size[1] / dpi)
111
  fig = plt.figure(figsize=figsize, dpi=dpi)
112
 
113
- # Remove margins and spacing
114
  ax = plt.Axes(fig, [0., 0., 1., 1.])
115
  fig.add_axes(ax)
116
 
@@ -142,47 +143,55 @@ def process_image_detection(image, target_label, surprise_rating):
142
  mask = masks[0].numpy() if isinstance(masks[0], torch.Tensor) else masks[0]
143
  show_mask(mask, ax=ax)
144
 
 
145
  rect = patches.Rectangle(
146
  (box[0], box[1]),
147
  box[2] - box[0],
148
  box[3] - box[1],
149
- linewidth=2,
150
  edgecolor='red',
151
  facecolor='none'
152
  )
153
  ax.add_patch(rect)
154
 
 
155
  plt.text(
156
- box[0], box[1] - 5,
157
  f'{max_score:.2f}',
158
- color='red'
 
 
 
159
  )
160
 
 
161
  plt.text(
162
- box[2] + 5, box[1],
163
  f'Unexpected (Rating: {surprise_rating}/5)\n{target_label}',
164
  color='red',
165
- fontsize=10,
 
 
166
  verticalalignment='bottom'
167
  )
168
 
169
  plt.axis('off')
170
 
171
- # Save with original resolution and DPI
172
  buf = io.BytesIO()
173
  plt.savefig(buf,
174
  format='png',
175
  dpi=dpi,
176
  bbox_inches='tight',
177
- pad_inches=0)
 
178
  buf.seek(0)
179
  plt.close()
180
 
181
- # Open the buffer and create a new image with original properties
182
  output_image = Image.open(buf)
183
  output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
184
 
185
- # Create a new buffer with the properly sized image
186
  final_buf = io.BytesIO()
187
  output_image.save(final_buf, format='PNG', dpi=original_dpi)
188
  final_buf.seek(0)
 
87
  device = "cuda" if torch.cuda.is_available() else "cpu"
88
 
89
  # Get original image DPI and size
90
+ original_dpi = image.info.get('dpi', (72, 72))
91
  original_size = image.size
92
 
93
+ # Calculate relative font size based on image dimensions
94
+ base_fontsize = min(original_size) / 40 # Adjust this divisor to change overall font size
95
+
96
  owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-large-patch14")
97
  owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-large-patch14").to(device)
98
 
 
108
  target_sizes = torch.tensor([image.size[::-1]]).to(device)
109
  results = owlv2_processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
110
 
111
+ dpi = 300 # Increased DPI for better text rendering
 
112
  figsize = (original_size[0] / dpi, original_size[1] / dpi)
113
  fig = plt.figure(figsize=figsize, dpi=dpi)
114
 
 
115
  ax = plt.Axes(fig, [0., 0., 1., 1.])
116
  fig.add_axes(ax)
117
 
 
143
  mask = masks[0].numpy() if isinstance(masks[0], torch.Tensor) else masks[0]
144
  show_mask(mask, ax=ax)
145
 
146
+ # Draw rectangle with increased line width
147
  rect = patches.Rectangle(
148
  (box[0], box[1]),
149
  box[2] - box[0],
150
  box[3] - box[1],
151
+ linewidth=max(2, min(original_size) / 500), # Scale line width with image size
152
  edgecolor='red',
153
  facecolor='none'
154
  )
155
  ax.add_patch(rect)
156
 
157
+ # Add confidence score with improved visibility
158
  plt.text(
159
+ box[0], box[1] - base_fontsize,
160
  f'{max_score:.2f}',
161
+ color='red',
162
+ fontsize=base_fontsize,
163
+ fontweight='bold',
164
+ bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2)
165
  )
166
 
167
+ # Add label and rating with improved visibility
168
  plt.text(
169
+ box[2] + base_fontsize / 2, box[1],
170
  f'Unexpected (Rating: {surprise_rating}/5)\n{target_label}',
171
  color='red',
172
+ fontsize=base_fontsize,
173
+ fontweight='bold',
174
+ bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2),
175
  verticalalignment='bottom'
176
  )
177
 
178
  plt.axis('off')
179
 
180
+ # Save with high DPI
181
  buf = io.BytesIO()
182
  plt.savefig(buf,
183
  format='png',
184
  dpi=dpi,
185
  bbox_inches='tight',
186
+ pad_inches=0,
187
+ metadata={'dpi': original_dpi})
188
  buf.seek(0)
189
  plt.close()
190
 
191
+ # Process final image
192
  output_image = Image.open(buf)
193
  output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
194
 
 
195
  final_buf = io.BytesIO()
196
  output_image.save(final_buf, format='PNG', dpi=original_dpi)
197
  final_buf.seek(0)