sohamnk committed
Commit 4f9c351 · verified · 1 Parent(s): b7e2432
Files changed (1):
  app.py +28 -38
app.py CHANGED
@@ -27,8 +27,8 @@ FEATURE_WEIGHTS = {
 # ---- Load Models ----
 print("="*50)
 print("🚀 Initializing application and loading models...")
-# Set device to CPU for compatibility with Hugging Face Spaces free tier
-device = torch.device('cpu')
+device_name = os.environ.get("device", "cpu")
+device = torch.device('cuda' if 'cuda' in device_name and torch.cuda.is_available() else 'cpu')
 print(f"🧠 Using device: {device}")

 print("...Loading Grounding DINO model...")
@@ -118,12 +118,7 @@ def extract_features(segmented_image: Image.Image) -> dict:

     color_hist = cv2.calcHist([image_rgb], [0, 1, 2], mask, [8, 8, 8], [0, 256, 0, 256, 0, 256])
     cv2.normalize(color_hist, color_hist)
-
-    # ------------------ THE FIX IS HERE ------------------
-    # The color_hist is multi-dimensional. We must flatten it to a 1D array
-    # before converting it to a list for the JSON response.
-    flat_color_hist = color_hist.flatten()
-    # ----------------------------------------------------
+    color_hist = color_hist.flatten()

     gray_masked = cv2.bitwise_and(gray, gray, mask=mask)
     lbp = feature.local_binary_pattern(gray_masked, P=24, R=3, method="uniform")
@@ -133,7 +128,7 @@ def extract_features(segmented_image: Image.Image) -> dict:

     return {
         "shape_features": hu_moments.tolist(),
-        "color_features": flat_color_hist.tolist(), # Use the flattened array
+        "color_features": color_hist.tolist(),
         "texture_features": texture_hist.tolist()
     }
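For context on the flatten: cv2.calcHist with three channels and 8 bins per channel returns an 8×8×8 array, so calling .tolist() on it directly would produce triply nested lists. Flattening first makes color_features a flat 512-float list, which compare_items can rebuild into the 1-D float32 array it feeds to cv2.compareHist. A small illustration (standalone, with a dummy array standing in for the real histogram):

import numpy as np

hist = np.random.rand(8, 8, 8).astype("float32")  # same shape as the calcHist output here

nested = hist.tolist()            # 8 lists of 8 lists of 8 floats
flat = hist.flatten().tolist()    # one flat list of 512 floats

print(len(nested), len(flat))     # 8 512

# Round-trip the flat list the way compare_items does:
restored = np.array(flat).astype("float32")
print(restored.shape)             # (512,)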
 
@@ -168,11 +163,12 @@ def process_item():

     object_name = data.get('objectName')
     description = data.get('objectDescription')
-    image_url = data.get('objectImage')
+    image_url = data.get('objectImage')  # This can now be null

     if not all([object_name, description]):
         return jsonify({"error": "objectName and objectDescription are required."}), 400

+    # --- Always process text-based features ---
     canonical_label = get_canonical_label(object_name)
     text_embedding = get_text_embedding(description)
 
@@ -181,11 +177,13 @@ def process_item():
         "text_embedding": text_embedding,
     }

+    # --- Process visual features ONLY if an image_url is provided ---
     if image_url:
         print("--- Image URL provided, processing visual features... ---")
         image = download_image_from_url(image_url)
         object_crop = detect_and_crop(image, canonical_label)
         visual_features = extract_features(object_crop)
+        # Add visual features to the response
         response_data.update(visual_features)
     else:
         print("--- No image URL provided, skipping visual feature extraction. ---")
@@ -215,12 +213,9 @@ def compare_items():
         return jsonify({"error": "queryItem and searchList are required."}), 400

     query_text_emb = np.array(query_item['text_embedding'])
-    query_has_image = 'shape_features' in query_item and query_item['shape_features'] is not None
-
-    if query_has_image:
-        query_shape_feat = np.array(query_item['shape_features'])
-        query_color_feat = np.array(query_item['color_features']).astype("float32")
-        query_texture_feat = np.array(query_item['texture_features']).astype("float32")
+    query_shape_feat = np.array(query_item['shape_features'])
+    query_color_feat = np.array(query_item['color_features']).astype("float32")
+    query_texture_feat = np.array(query_item['texture_features']).astype("float32")

     results = []
     print(f"--- Comparing 1 query item against {len(search_list)} items ---")
@@ -233,29 +228,24 @@ def compare_items():
         text_score = cosine_similarity(query_text_emb, text_emb_found)
         print(f" - Text Score: {text_score:.4f}")

-        final_score = text_score
-        image_score = 0.0
+        found_shape = np.array(item['shape_features'])
+        found_color = np.array(item['color_features']).astype("float32")
+        found_texture = np.array(item['texture_features']).astype("float32")
+
+        shape_dist = cv2.matchShapes(query_shape_feat, found_shape, cv2.CONTOURS_MATCH_I1, 0.0)
+        shape_score = 1.0 / (1.0 + shape_dist)

-        item_has_image = 'shape_features' in item and item['shape_features'] is not None
-
-        if query_has_image and item_has_image:
-            found_shape = np.array(item['shape_features'])
-            found_color = np.array(item['color_features']).astype("float32")
-            found_texture = np.array(item['texture_features']).astype("float32")
-
-            shape_dist = cv2.matchShapes(query_shape_feat, found_shape, cv2.CONTOURS_MATCH_I1, 0.0)
-            shape_score = 1.0 / (1.0 + shape_dist)
-
-            color_score = cv2.compareHist(query_color_feat, found_color, cv2.HISTCMP_CORREL)
-            texture_score = cv2.compareHist(query_texture_feat, found_texture, cv2.HISTCMP_CORREL)
-
-            image_score = (FEATURE_WEIGHTS["shape"] * shape_score +
-                           FEATURE_WEIGHTS["color"] * color_score +
-                           FEATURE_WEIGHTS["texture"] * texture_score)
-
-            final_score = 0.4 * image_score + 0.6 * text_score
-            print(f" - Image Score: {image_score:.4f} | Final Score: {final_score:.4f}")
+        color_score = cv2.compareHist(query_color_feat, found_color, cv2.HISTCMP_CORREL)
+        texture_score = cv2.compareHist(query_texture_feat, found_texture, cv2.HISTCMP_CORREL)
+
+        image_score = (FEATURE_WEIGHTS["shape"] * shape_score +
+                       FEATURE_WEIGHTS["color"] * color_score +
+                       FEATURE_WEIGHTS["texture"] * texture_score)
+
+        final_score = 0.4 * image_score + 0.6 * text_score

+        print(f" - Image Score: {image_score:.4f} | Final Score: {final_score:.4f}")
+
         results.append({
             "_id": item_id,
             "score": round(final_score, 4),
@@ -279,4 +269,4 @@ def compare_items():
         return jsonify({"error": str(e)}), 500

 if __name__ == '__main__':
-    app.run(host='0.0.0.0', port=7860)
+    app.run(host='0.0.0.0', port=7860)