sohamnk commited on
Commit
0d341d8
·
verified ·
1 Parent(s): 9e8ec58

Update pipeline/routes.py

Browse files
Files changed (1) hide show
  1. pipeline/routes.py +11 -22
pipeline/routes.py CHANGED
@@ -6,6 +6,9 @@ from flask import request, jsonify
6
  # Import app, models, and logic functions
7
  from pipeline import app, models, logic
8
 
 
 
 
9
  @app.route('/process', methods=['POST'])
10
  def process_item():
11
  print("\n" + "="*50)
@@ -16,7 +19,7 @@ def process_item():
16
 
17
  object_name = data.get('objectName')
18
  description = data.get('objectDescription')
19
- image_url = data.get('objectImage')
20
 
21
  if not all([object_name, description]):
22
  return jsonify({"error": "objectName and objectDescription are required."}), 400
@@ -47,10 +50,6 @@ def process_item():
47
  traceback.print_exc()
48
  return jsonify({"error": str(e)}), 500
49
 
50
- @app.route('/compare', methods=['POST'])
51
- # Add a new constant at the top of the file
52
- TOP_N_CANDIDATES = 20 # The number of items to re-rank
53
-
54
  @app.route('/compare', methods=['POST'])
55
  def compare_items():
56
  print("\n" + "="*50)
@@ -64,19 +63,16 @@ def compare_items():
64
 
65
  if not all([query_item, search_list]):
66
  return jsonify({"error": "queryItem and searchList are required."}), 400
67
-
68
  # === STAGE 1: FAST RETRIEVAL (using Bi-Encoder) ===
69
  print(f"--- Stage 1: Retrieving top candidates from {len(search_list)} items... ---")
70
-
71
  initial_candidates = []
72
  query_text_emb = np.array(query_item['text_embedding'])
73
 
74
  for item in search_list:
75
  text_emb_found = np.array(item['text_embedding'])
76
  text_score = logic.cosine_similarity(query_text_emb, text_emb_found)
77
-
78
- # For now, just use the text_score as the initial score
79
- # We will calculate the full score later for the top candidates
80
  initial_candidates.append({"item": item, "initial_score": text_score})
81
 
82
  # Sort by the initial score and keep the best ones
@@ -91,10 +87,10 @@ def compare_items():
91
 
92
  print(f"\n--- Stage 2: Re-ranking top {len(top_candidates)} candidates... ---")
93
  query_description = query_item['objectDescription']
94
-
95
  # Create pairs of [query, candidate_description] for the cross-encoder
96
  rerank_pairs = [(query_description, cand['item']['objectDescription']) for cand in top_candidates]
97
-
98
  # Get new, highly accurate scores from the cross-encoder
99
  cross_encoder_scores = models['cross_encoder'].predict(rerank_pairs)
100
 
@@ -109,9 +105,8 @@ def compare_items():
109
  # Now we calculate the final image and combined score, just like before
110
  has_query_image = 'shape_features' in query_item and query_item['shape_features']
111
  has_item_image = 'shape_features' in item and item['shape_features']
112
-
113
  if has_query_image and has_item_image:
114
- # (This image scoring logic is the same as your old code)
115
  from pipeline import FEATURE_WEIGHTS
116
  query_shape = np.array(query_item['shape_features'])
117
  query_color = np.array(query_item['color_features']).astype("float32")
@@ -127,11 +122,10 @@ def compare_items():
127
  FEATURE_WEIGHTS["color"] * color_score +
128
  FEATURE_WEIGHTS["texture"] * texture_score)
129
  image_score = logic.stretch_image_score(raw_image_score)
130
- # Use the new cross_score for the text part
131
- final_score = 0.4 * image_score + 0.6 * cross_score
132
  print(f" - Image Score: {image_score:.4f} | Final Re-ranked Score: {final_score:.4f}")
133
  else:
134
- final_score = cross_score # If no image, the final score is the cross-encoder score
135
 
136
  from pipeline import FINAL_SCORE_THRESHOLD
137
  if final_score >= FINAL_SCORE_THRESHOLD:
@@ -151,11 +145,6 @@ def compare_items():
151
  print("="*50)
152
  return jsonify({"matches": final_results}), 200
153
 
154
- except Exception as e:
155
- print(f"❌ Error in /compare: {e}")
156
- traceback.print_exc()
157
- return jsonify({"error": str(e)}), 500
158
-
159
  except Exception as e:
160
  print(f"❌ Error in /compare: {e}")
161
  traceback.print_exc()
 
6
  # Import app, models, and logic functions
7
  from pipeline import app, models, logic
8
 
9
+ # This constant should be at the top level, after imports
10
+ TOP_N_CANDIDATES = 20
11
+
12
  @app.route('/process', methods=['POST'])
13
  def process_item():
14
  print("\n" + "="*50)
 
19
 
20
  object_name = data.get('objectName')
21
  description = data.get('objectDescription')
22
+ image_url = data.get('objectImage')
23
 
24
  if not all([object_name, description]):
25
  return jsonify({"error": "objectName and objectDescription are required."}), 400
 
50
  traceback.print_exc()
51
  return jsonify({"error": str(e)}), 500
52
 
 
 
 
 
53
  @app.route('/compare', methods=['POST'])
54
  def compare_items():
55
  print("\n" + "="*50)
 
63
 
64
  if not all([query_item, search_list]):
65
  return jsonify({"error": "queryItem and searchList are required."}), 400
66
+
67
  # === STAGE 1: FAST RETRIEVAL (using Bi-Encoder) ===
68
  print(f"--- Stage 1: Retrieving top candidates from {len(search_list)} items... ---")
69
+
70
  initial_candidates = []
71
  query_text_emb = np.array(query_item['text_embedding'])
72
 
73
  for item in search_list:
74
  text_emb_found = np.array(item['text_embedding'])
75
  text_score = logic.cosine_similarity(query_text_emb, text_emb_found)
 
 
 
76
  initial_candidates.append({"item": item, "initial_score": text_score})
77
 
78
  # Sort by the initial score and keep the best ones
 
87
 
88
  print(f"\n--- Stage 2: Re-ranking top {len(top_candidates)} candidates... ---")
89
  query_description = query_item['objectDescription']
90
+
91
  # Create pairs of [query, candidate_description] for the cross-encoder
92
  rerank_pairs = [(query_description, cand['item']['objectDescription']) for cand in top_candidates]
93
+
94
  # Get new, highly accurate scores from the cross-encoder
95
  cross_encoder_scores = models['cross_encoder'].predict(rerank_pairs)
96
 
 
105
  # Now we calculate the final image and combined score, just like before
106
  has_query_image = 'shape_features' in query_item and query_item['shape_features']
107
  has_item_image = 'shape_features' in item and item['shape_features']
108
+
109
  if has_query_image and has_item_image:
 
110
  from pipeline import FEATURE_WEIGHTS
111
  query_shape = np.array(query_item['shape_features'])
112
  query_color = np.array(query_item['color_features']).astype("float32")
 
122
  FEATURE_WEIGHTS["color"] * color_score +
123
  FEATURE_WEIGHTS["texture"] * texture_score)
124
  image_score = logic.stretch_image_score(raw_image_score)
125
+ final_score = 0.4 * image_score + 0.6 * cross_score
 
126
  print(f" - Image Score: {image_score:.4f} | Final Re-ranked Score: {final_score:.4f}")
127
  else:
128
+ final_score = cross_score
129
 
130
  from pipeline import FINAL_SCORE_THRESHOLD
131
  if final_score >= FINAL_SCORE_THRESHOLD:
 
145
  print("="*50)
146
  return jsonify({"matches": final_results}), 200
147
 
 
 
 
 
 
148
  except Exception as e:
149
  print(f"❌ Error in /compare: {e}")
150
  traceback.print_exc()