sohamnk commited on
Commit
b6cabef
Β·
verified Β·
1 Parent(s): dbeb7ca

Update pipeline/routes.py

Browse files
Files changed (1) hide show
  1. pipeline/routes.py +76 -85
pipeline/routes.py CHANGED
@@ -6,16 +6,16 @@ from flask import request, jsonify
6
  # Import app, models, and logic functions
7
  from pipeline import app, models, logic
8
 
9
- # This constant should be at the top level, after imports
10
- TOP_N_CANDIDATES = 20
11
 
12
  @app.route('/process', methods=['POST'])
13
  def process_item():
14
- print("\n" + "="*50)
15
  print("➑ [Request] Received new request to /process")
 
16
  try:
17
  data = request.get_json()
18
- if not data: return jsonify({"error": "Invalid JSON payload"}), 400
 
19
 
20
  object_name = data.get('objectName')
21
  description = data.get('objectDescription')
@@ -42,7 +42,7 @@ def process_item():
42
  print("--- No image URL provided, skipping visual feature extraction. ---")
43
 
44
  print("βœ… Successfully processed item.")
45
- print("="*50)
46
  return jsonify(response_data), 200
47
 
48
  except Exception as e:
@@ -50,13 +50,16 @@ def process_item():
50
  traceback.print_exc()
51
  return jsonify({"error": str(e)}), 500
52
 
 
53
  @app.route('/compare', methods=['POST'])
54
  def compare_items():
55
- print("\n" + "="*50)
56
  print("➑ [Request] Received new request to /compare")
 
57
  try:
58
  data = request.get_json()
59
- if not data: return jsonify({"error": "Invalid JSON payload"}), 400
 
60
 
61
  query_item = data.get('queryItem')
62
  search_list = data.get('searchList')
@@ -64,88 +67,76 @@ def compare_items():
64
  if not all([query_item, search_list]):
65
  return jsonify({"error": "queryItem and searchList are required."}), 400
66
 
67
- # === STAGE 1: FAST RETRIEVAL (using Bi-Encoder) ===
68
- print(f"--- Stage 1: Retrieving top candidates from {len(search_list)} items... ---")
69
-
70
- initial_candidates = []
71
  query_text_emb = np.array(query_item['text_embedding'])
 
 
72
 
73
  for item in search_list:
74
- text_emb_found = np.array(item['text_embedding'])
75
- text_score = logic.cosine_similarity(query_text_emb, text_emb_found)
76
- initial_candidates.append({"item": item, "initial_score": text_score})
77
-
78
- # Sort by the initial score and keep the best ones
79
- initial_candidates.sort(key=lambda x: x["initial_score"], reverse=True)
80
- top_candidates = initial_candidates[:TOP_N_CANDIDATES]
81
- print(f"--- Found {len(top_candidates)} candidates for re-ranking. ---")
82
-
83
- # === STAGE 2: ACCURATE RE-RANKING (using Cross-Encoder) ===
84
- if not top_candidates:
85
- print("βœ… No potential matches found in Stage 1.")
86
- return jsonify({"matches": []}), 200
87
-
88
- print(f"\n--- Stage 2: Re-ranking top {len(top_candidates)} candidates... ---")
89
- query_description = query_item['objectDescription']
90
-
91
- # Create pairs of [query, candidate_description] for the cross-encoder
92
- rerank_pairs = [(query_description, cand['item']['objectDescription']) for cand in top_candidates]
93
-
94
- # Get new, highly accurate scores from the cross-encoder
95
- cross_encoder_scores = models['cross_encoder'].predict(rerank_pairs)
96
-
97
- # Now, build the final results with the new scores
98
- final_results = []
99
- for i, candidate_data in enumerate(top_candidates):
100
- item = candidate_data['item']
101
- cross_score = cross_encoder_scores[i] # Get the new text score
102
- print(f"\n [Re-Ranking] Item ID: {item.get('_id')}")
103
- print(f" - Cross-Encoder Score: {cross_score:.4f}")
104
-
105
- # Now we calculate the final image and combined score, just like before
106
- has_query_image = 'shape_features' in query_item and query_item['shape_features']
107
- has_item_image = 'shape_features' in item and item['shape_features']
108
-
109
- if has_query_image and has_item_image:
110
- from pipeline import FEATURE_WEIGHTS
111
- query_shape = np.array(query_item['shape_features'])
112
- query_color = np.array(query_item['color_features']).astype("float32")
113
- query_texture = np.array(query_item['texture_features']).astype("float32")
114
- found_shape = np.array(item['shape_features'])
115
- found_color = np.array(item['color_features']).astype("float32")
116
- found_texture = np.array(item['texture_features']).astype("float32")
117
- shape_dist = cv2.matchShapes(query_shape, found_shape, cv2.CONTOURS_MATCH_I1, 0.0)
118
- shape_score = 1.0 / (1.0 + shape_dist)
119
- color_score = cv2.compareHist(query_color, found_color, cv2.HISTCMP_CORREL)
120
- texture_score = cv2.compareHist(query_texture, found_texture, cv2.HISTCMP_CORREL)
121
- raw_image_score = (FEATURE_WEIGHTS["shape"] * shape_score +
122
- FEATURE_WEIGHTS["color"] * color_score +
123
- FEATURE_WEIGHTS["texture"] * texture_score)
124
- image_score = logic.stretch_image_score(raw_image_score)
125
- final_score = 0.4 * image_score + 0.6 * cross_score
126
- print(f" - Image Score: {image_score:.4f} | Final Re-ranked Score: {final_score:.4f}")
127
- else:
128
- final_score = cross_score
129
-
130
- from pipeline import FINAL_SCORE_THRESHOLD
131
- if final_score >= FINAL_SCORE_THRESHOLD:
132
- print(f" - βœ… ACCEPTED (Score >= {FINAL_SCORE_THRESHOLD})")
133
- final_results.append({
134
- "_id": item.get('_id'),
135
- "score": round(final_score, 4),
136
- "objectName": item.get("objectName"),
137
- "objectDescription": item.get("objectDescription"),
138
- "objectImage": item.get("objectImage"),
139
- })
140
- else:
141
- print(f" - ❌ REJECTED (Score < {FINAL_SCORE_THRESHOLD})")
142
-
143
- final_results.sort(key=lambda x: x["score"], reverse=True)
144
- print(f"\nβœ… Search complete. Found {len(final_results)} final matches after re-ranking.")
145
- print("="*50)
146
- return jsonify({"matches": final_results}), 200
147
 
148
  except Exception as e:
149
  print(f"❌ Error in /compare: {e}")
150
  traceback.print_exc()
151
- return jsonify({"error": str(e)}), 500
 
6
  # Import app, models, and logic functions
7
  from pipeline import app, models, logic
8
 
 
 
9
 
10
  @app.route('/process', methods=['POST'])
11
  def process_item():
12
+ print("\n" + "=" * 50)
13
  print("➑ [Request] Received new request to /process")
14
+
15
  try:
16
  data = request.get_json()
17
+ if not data:
18
+ return jsonify({"error": "Invalid JSON payload"}), 400
19
 
20
  object_name = data.get('objectName')
21
  description = data.get('objectDescription')
 
42
  print("--- No image URL provided, skipping visual feature extraction. ---")
43
 
44
  print("βœ… Successfully processed item.")
45
+ print("=" * 50)
46
  return jsonify(response_data), 200
47
 
48
  except Exception as e:
 
50
  traceback.print_exc()
51
  return jsonify({"error": str(e)}), 500
52
 
53
+
54
  @app.route('/compare', methods=['POST'])
55
  def compare_items():
56
+ print("\n" + "=" * 50)
57
  print("➑ [Request] Received new request to /compare")
58
+
59
  try:
60
  data = request.get_json()
61
+ if not data:
62
+ return jsonify({"error": "Invalid JSON payload"}), 400
63
 
64
  query_item = data.get('queryItem')
65
  search_list = data.get('searchList')
 
67
  if not all([query_item, search_list]):
68
  return jsonify({"error": "queryItem and searchList are required."}), 400
69
 
 
 
 
 
70
  query_text_emb = np.array(query_item['text_embedding'])
71
+ results = []
72
+ print(f"--- Comparing 1 query item against {len(search_list)} items ---")
73
 
74
  for item in search_list:
75
+ item_id = item.get('_id')
76
+ print(f"\n [Checking] Item ID: {item_id}")
77
+
78
+ try:
79
+ text_emb_found = np.array(item['text_embedding'])
80
+ text_score = logic.cosine_similarity(query_text_emb, text_emb_found)
81
+ print(f" - Text Score: {text_score:.4f}")
82
+
83
+ has_query_image = 'shape_features' in query_item and query_item['shape_features']
84
+ has_item_image = 'shape_features' in item and item['shape_features']
85
+
86
+ if has_query_image and has_item_image:
87
+ print(" - Both items have images. Performing visual comparison.")
88
+ from pipeline import FEATURE_WEIGHTS # Import constant
89
+
90
+ query_shape = np.array(query_item['shape_features'])
91
+ query_color = np.array(query_item['color_features']).astype("float32")
92
+ query_texture = np.array(query_item['texture_features']).astype("float32")
93
+
94
+ found_shape = np.array(item['shape_features'])
95
+ found_color = np.array(item['color_features']).astype("float32")
96
+ found_texture = np.array(item['texture_features']).astype("float32")
97
+
98
+ shape_dist = cv2.matchShapes(query_shape, found_shape, cv2.CONTOURS_MATCH_I1, 0.0)
99
+ shape_score = 1.0 / (1.0 + shape_dist)
100
+ color_score = cv2.compareHist(query_color, found_color, cv2.HISTCMP_CORREL)
101
+ texture_score = cv2.compareHist(query_texture, found_texture, cv2.HISTCMP_CORREL)
102
+
103
+ raw_image_score = (
104
+ FEATURE_WEIGHTS["shape"] * shape_score +
105
+ FEATURE_WEIGHTS["color"] * color_score +
106
+ FEATURE_WEIGHTS["texture"] * texture_score
107
+ )
108
+
109
+ print(f"Raw Image Score: {raw_image_score:.4f}")
110
+ image_score = logic.stretch_image_score(raw_image_score)
111
+ final_score = 0.4 * image_score + 0.6 * text_score
112
+ print(f" - Image Score: {image_score:.4f} | Final Score: {final_score:.4f}")
113
+ else:
114
+ print(" - One or both items missing image. Using text score only.")
115
+ final_score = text_score
116
+
117
+ from pipeline import FINAL_SCORE_THRESHOLD # Import constant
118
+ if final_score >= FINAL_SCORE_THRESHOLD:
119
+ print(f" - βœ… ACCEPTED (Score >= {FINAL_SCORE_THRESHOLD})")
120
+ results.append({
121
+ "_id": item_id,
122
+ "score": round(final_score, 4),
123
+ "objectName": item.get("objectName"),
124
+ "objectDescription": item.get("objectDescription"),
125
+ "objectImage": item.get("objectImage"),
126
+ })
127
+ else:
128
+ print(f" - ❌ REJECTED (Score < {FINAL_SCORE_THRESHOLD})")
129
+
130
+ except Exception as e:
131
+ print(f" [Skipping] Item {item_id} due to processing error: {e}")
132
+ continue
133
+
134
+ results.sort(key=lambda x: x["score"], reverse=True)
135
+ print(f"\nβœ… Search complete. Found {len(results)} potential matches.")
136
+ print("=" * 50)
137
+ return jsonify({"matches": results}), 200
 
 
 
 
 
 
 
 
 
 
138
 
139
  except Exception as e:
140
  print(f"❌ Error in /compare: {e}")
141
  traceback.print_exc()
142
+ return jsonify({"error": str(e)}), 500