sohamnk's picture
Update pipeline/routes.py
0d341d8 verified
raw
history blame
6.9 kB
import traceback
import numpy as np
import cv2
from flask import request, jsonify
# Import app, models, and logic functions
from pipeline import app, models, logic
# Maximum number of Stage 1 candidates (ranked by text-embedding similarity)
# that /compare keeps and hands to the Stage 2 cross-encoder re-ranking.
TOP_N_CANDIDATES = 20
@app.route('/process', methods=['POST'])
def process_item():
    """Handle POST /process: compute the label, text embedding and visual features of one item.

    The request body must be a JSON object with non-empty ``objectName`` and
    ``objectDescription``; ``objectImage`` (an image URL) is optional.

    Returns:
        A ``(response, status)`` tuple:

        * 200 -- JSON with ``canonicalLabel`` and ``text_embedding``, updated
          with whatever mapping ``logic.extract_features`` returns when an
          image URL was supplied.
        * 400 -- the body is not a JSON object, or a required field is
          missing/empty.
        * 500 -- any processing step raised; the exception text is returned
          under ``error``.
    """
    print("\n" + "="*50)
    print("➑ [Request] Received new request to /process")
    try:
        # silent=True returns None for a malformed body or wrong content type
        # instead of raising; without it the parse error would be swallowed by
        # the generic handler below and reported as a 500 rather than a 400.
        data = request.get_json(silent=True)
        # A JSON array/scalar has no .get(); reject it as a client error too.
        if not data or not isinstance(data, dict):
            return jsonify({"error": "Invalid JSON payload"}), 400

        object_name = data.get('objectName')
        description = data.get('objectDescription')
        image_url = data.get('objectImage')
        if not all([object_name, description]):
            return jsonify({"error": "objectName and objectDescription are required."}), 400

        canonical_label = logic.get_canonical_label(object_name)
        text_embedding = logic.get_text_embedding(description, models)
        response_data = {
            "canonicalLabel": canonical_label,
            "text_embedding": text_embedding,
        }

        if image_url:
            print("--- Image URL provided, processing visual features... ---")
            image = logic.download_image_from_url(image_url)
            # The crop is guided by the canonical label computed above.
            object_crop = logic.detect_and_crop(image, canonical_label, models)
            visual_features = logic.extract_features(object_crop)
            response_data.update(visual_features)
        else:
            print("--- No image URL provided, skipping visual feature extraction. ---")

        print("βœ… Successfully processed item.")
        print("="*50)
        return jsonify(response_data), 200
    except Exception as e:
        # Top-level request boundary: log the traceback and report the failure.
        print(f"❌ Error in /process: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
def _image_similarity(query_item, item, weights):
    """Return the stretched, weighted shape/color/texture similarity of two items.

    Args:
        query_item: Mapping with ``shape_features``, ``color_features`` and
            ``texture_features`` for the query.
        item: Mapping with the same three keys for the candidate.
        weights: Mapping with ``"shape"``, ``"color"`` and ``"texture"``
            weights applied to the three partial scores.

    Returns:
        The result of ``logic.stretch_image_score`` applied to the weighted sum.
    """
    query_shape = np.array(query_item['shape_features'])
    found_shape = np.array(item['shape_features'])
    # Histogram-like features are cast to float32 before cv2.compareHist.
    query_color = np.array(query_item['color_features']).astype("float32")
    found_color = np.array(item['color_features']).astype("float32")
    query_texture = np.array(query_item['texture_features']).astype("float32")
    found_texture = np.array(item['texture_features']).astype("float32")

    shape_dist = cv2.matchShapes(query_shape, found_shape, cv2.CONTOURS_MATCH_I1, 0.0)
    # Turn the distance into a similarity: 0 distance -> 1.0, larger -> toward 0.
    shape_score = 1.0 / (1.0 + shape_dist)
    color_score = cv2.compareHist(query_color, found_color, cv2.HISTCMP_CORREL)
    texture_score = cv2.compareHist(query_texture, found_texture, cv2.HISTCMP_CORREL)

    raw_image_score = (weights["shape"] * shape_score +
                       weights["color"] * color_score +
                       weights["texture"] * texture_score)
    return logic.stretch_image_score(raw_image_score)


@app.route('/compare', methods=['POST'])
def compare_items():
    """Handle POST /compare: rank ``searchList`` items against ``queryItem``.

    The request body must be a JSON object with a ``queryItem`` and a
    non-empty ``searchList``. Matching runs in two stages:

    1. Every search item is scored by cosine similarity between text
       embeddings, and the best ``TOP_N_CANDIDATES`` are kept.
    2. The kept candidates are re-scored with ``models['cross_encoder']`` on
       the description pairs. When both the query and the candidate carry
       ``shape_features``, the final score is
       ``0.4 * image_score + 0.6 * cross_score``; otherwise it is the
       cross-encoder score alone. Candidates scoring below
       ``FINAL_SCORE_THRESHOLD`` are dropped.

    Returns:
        A ``(response, status)`` tuple:

        * 200 -- ``{"matches": [...]}`` sorted by descending ``score``.
        * 400 -- the body is not a JSON object, or ``queryItem`` /
          ``searchList`` is missing or empty.
        * 500 -- any step raised; the exception text is returned under
          ``error``.
    """
    print("\n" + "="*50)
    print("➑ [Request] Received new request to /compare")
    try:
        # silent=True returns None for a malformed body or wrong content type
        # instead of raising; without it the parse error would be swallowed by
        # the generic handler below and reported as a 500 rather than a 400.
        data = request.get_json(silent=True)
        # A JSON array/scalar has no .get(); reject it as a client error too.
        if not data or not isinstance(data, dict):
            return jsonify({"error": "Invalid JSON payload"}), 400

        query_item = data.get('queryItem')
        search_list = data.get('searchList')
        if not all([query_item, search_list]):
            return jsonify({"error": "queryItem and searchList are required."}), 400

        # === STAGE 1: FAST RETRIEVAL (using Bi-Encoder) ===
        print(f"--- Stage 1: Retrieving top candidates from {len(search_list)} items... ---")
        query_text_emb = np.array(query_item['text_embedding'])
        initial_candidates = [
            {
                "item": item,
                "initial_score": logic.cosine_similarity(
                    query_text_emb, np.array(item['text_embedding'])
                ),
            }
            for item in search_list
        ]

        # Sort by the initial score and keep the best ones
        initial_candidates.sort(key=lambda x: x["initial_score"], reverse=True)
        top_candidates = initial_candidates[:TOP_N_CANDIDATES]
        print(f"--- Found {len(top_candidates)} candidates for re-ranking. ---")

        # === STAGE 2: ACCURATE RE-RANKING (using Cross-Encoder) ===
        if not top_candidates:
            print("βœ… No potential matches found in Stage 1.")
            return jsonify({"matches": []}), 200

        print(f"\n--- Stage 2: Re-ranking top {len(top_candidates)} candidates... ---")

        # Kept as a function-scope import (as in the original code), but done
        # once per request instead of once per candidate inside the loop.
        from pipeline import FEATURE_WEIGHTS, FINAL_SCORE_THRESHOLD

        query_description = query_item['objectDescription']
        # Create pairs of (query, candidate_description) for the cross-encoder
        rerank_pairs = [
            (query_description, cand['item']['objectDescription'])
            for cand in top_candidates
        ]
        # One batched call scores every pair.
        cross_encoder_scores = models['cross_encoder'].predict(rerank_pairs)

        # Whether the query has visual features does not change per candidate.
        has_query_image = 'shape_features' in query_item and query_item['shape_features']

        final_results = []
        for candidate_data, raw_cross_score in zip(top_candidates, cross_encoder_scores):
            item = candidate_data['item']
            # NOTE(review): predict() presumably returns NumPy scalars (the
            # sentence-transformers default); float() makes the value
            # JSON-serializable and is a no-op for plain floats -- confirm.
            cross_score = float(raw_cross_score)
            print(f"\n [Re-Ranking] Item ID: {item.get('_id')}")
            print(f" - Cross-Encoder Score: {cross_score:.4f}")

            has_item_image = 'shape_features' in item and item['shape_features']
            if has_query_image and has_item_image:
                image_score = _image_similarity(query_item, item, FEATURE_WEIGHTS)
                final_score = 0.4 * image_score + 0.6 * cross_score
                print(f" - Image Score: {image_score:.4f} | Final Re-ranked Score: {final_score:.4f}")
            else:
                # No image on one side: rely on the text score alone.
                final_score = cross_score

            if final_score >= FINAL_SCORE_THRESHOLD:
                print(f" - βœ… ACCEPTED (Score >= {FINAL_SCORE_THRESHOLD})")
                final_results.append({
                    "_id": item.get('_id'),
                    # Built-in float so jsonify can always serialize the score.
                    "score": round(float(final_score), 4),
                    "objectName": item.get("objectName"),
                    "objectDescription": item.get("objectDescription"),
                    "objectImage": item.get("objectImage"),
                })
            else:
                print(f" - ❌ REJECTED (Score < {FINAL_SCORE_THRESHOLD})")

        final_results.sort(key=lambda x: x["score"], reverse=True)
        print(f"\nβœ… Search complete. Found {len(final_results)} final matches after re-ranking.")
        print("="*50)
        return jsonify({"matches": final_results}), 200
    except Exception as e:
        # Top-level request boundary: log the traceback and report the failure.
        print(f"❌ Error in /compare: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500