sohamnk committed on
Commit
754ba00
·
verified ·
1 Parent(s): ee57fa4

Update pipeline/routes.py

Browse files
Files changed (1) hide show
  1. pipeline/routes.py +90 -55
pipeline/routes.py CHANGED
@@ -47,6 +47,10 @@ def process_item():
47
  traceback.print_exc()
48
  return jsonify({"error": str(e)}), 500
49
 
 
 
 
 
50
  @app.route('/compare', methods=['POST'])
51
  def compare_items():
52
  print("\n" + "="*50)
@@ -61,65 +65,96 @@ def compare_items():
61
  if not all([query_item, search_list]):
62
  return jsonify({"error": "queryItem and searchList are required."}), 400
63
 
 
 
 
 
64
  query_text_emb = np.array(query_item['text_embedding'])
65
- results = []
66
- print(f"--- Comparing 1 query item against {len(search_list)} items ---")
67
 
68
  for item in search_list:
69
- item_id = item.get('_id')
70
- print(f"\n [Checking] Item ID: {item_id}")
71
- try:
72
- text_emb_found = np.array(item['text_embedding'])
73
- text_score = logic.cosine_similarity(query_text_emb, text_emb_found)
74
- print(f" - Text Score: {text_score:.4f}")
75
-
76
- has_query_image = 'shape_features' in query_item and query_item['shape_features']
77
- has_item_image = 'shape_features' in item and item['shape_features']
78
-
79
- if has_query_image and has_item_image:
80
- print(" - Both items have images. Performing visual comparison.")
81
- from pipeline import FEATURE_WEIGHTS # Import constant
82
- query_shape = np.array(query_item['shape_features'])
83
- query_color = np.array(query_item['color_features']).astype("float32")
84
- query_texture = np.array(query_item['texture_features']).astype("float32")
85
- found_shape = np.array(item['shape_features'])
86
- found_color = np.array(item['color_features']).astype("float32")
87
- found_texture = np.array(item['texture_features']).astype("float32")
88
- shape_dist = cv2.matchShapes(query_shape, found_shape, cv2.CONTOURS_MATCH_I1, 0.0)
89
- shape_score = 1.0 / (1.0 + shape_dist)
90
- color_score = cv2.compareHist(query_color, found_color, cv2.HISTCMP_CORREL)
91
- texture_score = cv2.compareHist(query_texture, found_texture, cv2.HISTCMP_CORREL)
92
- raw_image_score = (FEATURE_WEIGHTS["shape"] * shape_score +
93
- FEATURE_WEIGHTS["color"] * color_score +
94
- FEATURE_WEIGHTS["texture"] * texture_score)
95
- print(f"Raw Image Score: {raw_image_score:.4f}")
96
- image_score = logic.stretch_image_score(raw_image_score)
97
- final_score = 0.4 * image_score + 0.6 * text_score
98
- print(f" - Image Score: {image_score:.4f} | Final Score: {final_score:.4f}")
99
- else:
100
- print(" - One or both items missing image. Using text score only.")
101
- final_score = text_score
102
-
103
- from pipeline import FINAL_SCORE_THRESHOLD # Import constant
104
- if final_score >= FINAL_SCORE_THRESHOLD:
105
- print(f" - βœ… ACCEPTED (Score >= {FINAL_SCORE_THRESHOLD})")
106
- results.append({
107
- "_id": item_id,
108
- "score": round(final_score, 4),
109
- "objectName": item.get("objectName"),
110
- "objectDescription": item.get("objectDescription"),
111
- "objectImage": item.get("objectImage"),
112
- })
113
- else:
114
- print(f" - ❌ REJECTED (Score < {FINAL_SCORE_THRESHOLD})")
115
- except Exception as e:
116
- print(f" [Skipping] Item {item_id} due to processing error: {e}")
117
- continue
118
-
119
- results.sort(key=lambda x: x["score"], reverse=True)
120
- print(f"\nβœ… Search complete. Found {len(results)} potential matches.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  print("="*50)
122
- return jsonify({"matches": results}), 200
 
 
 
 
 
123
 
124
  except Exception as e:
125
  print(f"❌ Error in /compare: {e}")
 
47
  traceback.print_exc()
48
  return jsonify({"error": str(e)}), 500
49
 
50
+ @app.route('/compare', methods=['POST'])
51
+ # Add a new constant at the top of the file
52
+ TOP_N_CANDIDATES = 20 # The number of items to re-rank
53
+
54
  @app.route('/compare', methods=['POST'])
55
  def compare_items():
56
  print("\n" + "="*50)
 
65
  if not all([query_item, search_list]):
66
  return jsonify({"error": "queryItem and searchList are required."}), 400
67
 
68
+ # === STAGE 1: FAST RETRIEVAL (using Bi-Encoder) ===
69
+ print(f"--- Stage 1: Retrieving top candidates from {len(search_list)} items... ---")
70
+
71
+ initial_candidates = []
72
  query_text_emb = np.array(query_item['text_embedding'])
 
 
73
 
74
  for item in search_list:
75
+ text_emb_found = np.array(item['text_embedding'])
76
+ text_score = logic.cosine_similarity(query_text_emb, text_emb_found)
77
+
78
+ # For now, just use the text_score as the initial score
79
+ # We will calculate the full score later for the top candidates
80
+ initial_candidates.append({"item": item, "initial_score": text_score})
81
+
82
+ # Sort by the initial score and keep the best ones
83
+ initial_candidates.sort(key=lambda x: x["initial_score"], reverse=True)
84
+ top_candidates = initial_candidates[:TOP_N_CANDIDATES]
85
+ print(f"--- Found {len(top_candidates)} candidates for re-ranking. ---")
86
+
87
+ # === STAGE 2: ACCURATE RE-RANKING (using Cross-Encoder) ===
88
+ if not top_candidates:
89
+ print("βœ… No potential matches found in Stage 1.")
90
+ return jsonify({"matches": []}), 200
91
+
92
+ print(f"\n--- Stage 2: Re-ranking top {len(top_candidates)} candidates... ---")
93
+ query_description = query_item['objectDescription']
94
+
95
+ # Create pairs of [query, candidate_description] for the cross-encoder
96
+ rerank_pairs = [(query_description, cand['item']['objectDescription']) for cand in top_candidates]
97
+
98
+ # Get new, highly accurate scores from the cross-encoder
99
+ cross_encoder_scores = models['cross_encoder'].predict(rerank_pairs)
100
+
101
+ # Now, build the final results with the new scores
102
+ final_results = []
103
+ for i, candidate_data in enumerate(top_candidates):
104
+ item = candidate_data['item']
105
+ cross_score = cross_encoder_scores[i] # Get the new text score
106
+ print(f"\n [Re-Ranking] Item ID: {item.get('_id')}")
107
+ print(f" - Cross-Encoder Score: {cross_score:.4f}")
108
+
109
+ # Now we calculate the final image and combined score, just like before
110
+ has_query_image = 'shape_features' in query_item and query_item['shape_features']
111
+ has_item_image = 'shape_features' in item and item['shape_features']
112
+
113
+ if has_query_image and has_item_image:
114
+ # (This image scoring logic is the same as your old code)
115
+ from pipeline import FEATURE_WEIGHTS
116
+ query_shape = np.array(query_item['shape_features'])
117
+ query_color = np.array(query_item['color_features']).astype("float32")
118
+ query_texture = np.array(query_item['texture_features']).astype("float32")
119
+ found_shape = np.array(item['shape_features'])
120
+ found_color = np.array(item['color_features']).astype("float32")
121
+ found_texture = np.array(item['texture_features']).astype("float32")
122
+ shape_dist = cv2.matchShapes(query_shape, found_shape, cv2.CONTOURS_MATCH_I1, 0.0)
123
+ shape_score = 1.0 / (1.0 + shape_dist)
124
+ color_score = cv2.compareHist(query_color, found_color, cv2.HISTCMP_CORREL)
125
+ texture_score = cv2.compareHist(query_texture, found_texture, cv2.HISTCMP_CORREL)
126
+ raw_image_score = (FEATURE_WEIGHTS["shape"] * shape_score +
127
+ FEATURE_WEIGHTS["color"] * color_score +
128
+ FEATURE_WEIGHTS["texture"] * texture_score)
129
+ image_score = logic.stretch_image_score(raw_image_score)
130
+ # Use the new cross_score for the text part
131
+ final_score = 0.4 * image_score + 0.6 * cross_score
132
+ print(f" - Image Score: {image_score:.4f} | Final Re-ranked Score: {final_score:.4f}")
133
+ else:
134
+ final_score = cross_score # If no image, the final score is the cross-encoder score
135
+
136
+ from pipeline import FINAL_SCORE_THRESHOLD
137
+ if final_score >= FINAL_SCORE_THRESHOLD:
138
+ print(f" - βœ… ACCEPTED (Score >= {FINAL_SCORE_THRESHOLD})")
139
+ final_results.append({
140
+ "_id": item.get('_id'),
141
+ "score": round(final_score, 4),
142
+ "objectName": item.get("objectName"),
143
+ "objectDescription": item.get("objectDescription"),
144
+ "objectImage": item.get("objectImage"),
145
+ })
146
+ else:
147
+ print(f" - ❌ REJECTED (Score < {FINAL_SCORE_THRESHOLD})")
148
+
149
+ final_results.sort(key=lambda x: x["score"], reverse=True)
150
+ print(f"\nβœ… Search complete. Found {len(final_results)} final matches after re-ranking.")
151
  print("="*50)
152
+ return jsonify({"matches": final_results}), 200
153
+
154
+ except Exception as e:
155
+ print(f"❌ Error in /compare: {e}")
156
+ traceback.print_exc()
157
+ return jsonify({"error": str(e)}), 500
158
 
159
  except Exception as e:
160
  print(f"❌ Error in /compare: {e}")