File size: 6,901 Bytes
7a4305a 0d341d8 7a4305a 0d341d8 7a4305a 0d341d8 754ba00 0d341d8 754ba00 7a4305a 754ba00 0d341d8 754ba00 0d341d8 754ba00 0d341d8 754ba00 0d341d8 754ba00 0d341d8 754ba00 7a4305a 754ba00 7a4305a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import traceback
import numpy as np
import cv2
from flask import request, jsonify
# Import app, models, and logic functions
from pipeline import app, models, logic
# Maximum number of Stage 1 candidates (ranked by text-embedding similarity)
# that /compare keeps and forwards to the Stage 2 cross-encoder re-ranking.
TOP_N_CANDIDATES = 20
@app.route('/process', methods=['POST'])
def process_item():
    """Handle ``POST /process``: compute the stored features for one item.

    The JSON request body is read for three keys:

    * ``objectName`` (required) -- passed to ``logic.get_canonical_label``.
    * ``objectDescription`` (required) -- passed to
      ``logic.get_text_embedding``.
    * ``objectImage`` (optional) -- an image URL. When present, the image is
      downloaded, cropped via ``logic.detect_and_crop`` and the result of
      ``logic.extract_features`` is merged into the response.

    Returns:
        A ``(response, status)`` tuple:

        * 200 -- JSON with ``canonicalLabel`` and ``text_embedding`` plus,
          when an image URL was given, the extracted visual features.
        * 400 -- the body is missing, empty or not valid JSON, or one of the
          required keys is missing/empty.
        * 500 -- any exception raised while processing; the body carries
          ``str(exception)`` under ``error``.
    """
    print("\n" + "="*50)
    print("⚡ [Request] Received new request to /process")
    try:
        # silent=True makes Flask return None for a malformed / non-JSON body
        # instead of raising, so the client gets the 400 below rather than
        # the generic 500 produced by the catch-all handler.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"error": "Invalid JSON payload"}), 400

        object_name = data.get('objectName')
        description = data.get('objectDescription')
        image_url = data.get('objectImage')
        if not (object_name and description):
            return jsonify({"error": "objectName and objectDescription are required."}), 400

        canonical_label = logic.get_canonical_label(object_name)
        # NOTE(review): text_embedding is handed to jsonify as-is, so
        # logic.get_text_embedding presumably returns a JSON-serializable
        # value (e.g. a list of floats) -- confirm.
        text_embedding = logic.get_text_embedding(description, models)
        response_data = {
            "canonicalLabel": canonical_label,
            "text_embedding": text_embedding,
        }

        if image_url:
            print("--- Image URL provided, processing visual features... ---")
            image = logic.download_image_from_url(image_url)
            object_crop = logic.detect_and_crop(image, canonical_label, models)
            # Presumably a dict of feature name -> values; it is merged
            # directly into the top level of the response.
            visual_features = logic.extract_features(object_crop)
            response_data.update(visual_features)
        else:
            print("--- No image URL provided, skipping visual feature extraction. ---")

        print("✅ Successfully processed item.")
        print("="*50)
        return jsonify(response_data), 200
    except Exception as e:
        # Top-level request boundary: log the full traceback and report the
        # failure to the caller as a JSON 500.
        print(f"❌ Error in /process: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
def _image_similarity(query_item, item, feature_weights):
    """Return the stretched visual similarity between two items.

    Both arguments must provide ``shape_features``, ``color_features`` and
    ``texture_features``. Shape is compared with ``cv2.matchShapes``
    (``CONTOURS_MATCH_I1``) and the resulting distance ``d`` is mapped to
    ``1 / (1 + d)``; color and texture are compared with ``cv2.compareHist``
    (``HISTCMP_CORREL``). The three scores are combined using the ``shape``,
    ``color`` and ``texture`` entries of ``feature_weights`` and the weighted
    sum is passed through ``logic.stretch_image_score``.
    """
    query_shape = np.array(query_item['shape_features'])
    query_color = np.array(query_item['color_features']).astype("float32")
    query_texture = np.array(query_item['texture_features']).astype("float32")
    found_shape = np.array(item['shape_features'])
    found_color = np.array(item['color_features']).astype("float32")
    found_texture = np.array(item['texture_features']).astype("float32")

    shape_dist = cv2.matchShapes(query_shape, found_shape, cv2.CONTOURS_MATCH_I1, 0.0)
    shape_score = 1.0 / (1.0 + shape_dist)
    color_score = cv2.compareHist(query_color, found_color, cv2.HISTCMP_CORREL)
    texture_score = cv2.compareHist(query_texture, found_texture, cv2.HISTCMP_CORREL)

    raw_image_score = (feature_weights["shape"] * shape_score +
                       feature_weights["color"] * color_score +
                       feature_weights["texture"] * texture_score)
    return logic.stretch_image_score(raw_image_score)


@app.route('/compare', methods=['POST'])
def compare_items():
    """Handle ``POST /compare``: rank ``searchList`` items against ``queryItem``.

    The JSON request body must contain:

    * ``queryItem`` -- object with ``text_embedding`` and
      ``objectDescription``; optionally ``shape_features``,
      ``color_features`` and ``texture_features``.
    * ``searchList`` -- non-empty list of objects with the same fields
      (``_id``, ``objectName`` and ``objectImage`` are echoed back if present).

    Matching runs in two stages:

    1. Every search item is scored with ``logic.cosine_similarity`` between
       text embeddings and the best ``TOP_N_CANDIDATES`` are kept.
    2. Those candidates are re-scored with ``models['cross_encoder']`` on
       the description pairs. When both the query and the candidate carry
       ``shape_features``, the final score is
       ``0.4 * image_score + 0.6 * cross_score``; otherwise it is the
       cross-encoder score alone. Candidates scoring below
       ``FINAL_SCORE_THRESHOLD`` are dropped.

    Returns:
        A ``(response, status)`` tuple:

        * 200 -- ``{"matches": [...]}`` sorted by descending ``score``.
        * 400 -- the body is missing, empty or not valid JSON, or
          ``queryItem`` / ``searchList`` is missing or empty.
        * 500 -- any exception raised while comparing; the body carries
          ``str(exception)`` under ``error``.
    """
    print("\n" + "="*50)
    print("⚡ [Request] Received new request to /compare")
    try:
        # silent=True makes Flask return None for a malformed / non-JSON body
        # instead of raising, so the client gets the 400 below rather than
        # the generic 500 produced by the catch-all handler.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"error": "Invalid JSON payload"}), 400

        query_item = data.get('queryItem')
        search_list = data.get('searchList')
        if not (query_item and search_list):
            return jsonify({"error": "queryItem and searchList are required."}), 400

        # === STAGE 1: FAST RETRIEVAL (using Bi-Encoder) ===
        print(f"--- Stage 1: Retrieving top candidates from {len(search_list)} items... ---")
        query_text_emb = np.array(query_item['text_embedding'])
        initial_candidates = [
            {
                "item": item,
                "initial_score": logic.cosine_similarity(
                    query_text_emb, np.array(item['text_embedding'])
                ),
            }
            for item in search_list
        ]

        # Sort by the initial score and keep the best ones
        initial_candidates.sort(key=lambda cand: cand["initial_score"], reverse=True)
        top_candidates = initial_candidates[:TOP_N_CANDIDATES]
        print(f"--- Found {len(top_candidates)} candidates for re-ranking. ---")

        # === STAGE 2: ACCURATE RE-RANKING (using Cross-Encoder) ===
        if not top_candidates:
            print("✅ No potential matches found in Stage 1.")
            return jsonify({"matches": []}), 200

        print(f"\n--- Stage 2: Re-ranking top {len(top_candidates)} candidates... ---")
        query_description = query_item['objectDescription']

        # Create pairs of (query, candidate_description) for the cross-encoder
        rerank_pairs = [
            (query_description, cand['item']['objectDescription'])
            for cand in top_candidates
        ]
        # Get new scores from the cross-encoder, one per pair, in order
        cross_encoder_scores = models['cross_encoder'].predict(rerank_pairs)

        # Imported once here rather than on every loop iteration. Kept as
        # function-scope imports as in the original -- presumably to avoid an
        # import cycle with the pipeline package; confirm before moving them
        # to module level.
        from pipeline import FEATURE_WEIGHTS, FINAL_SCORE_THRESHOLD

        # Loop-invariant: whether the query carries visual features at all.
        has_query_image = bool(query_item.get('shape_features'))

        final_results = []
        for candidate_data, raw_cross_score in zip(top_candidates, cross_encoder_scores):
            item = candidate_data['item']
            # Convert to a built-in float: the model output may be a NumPy
            # scalar (e.g. float32), which round() preserves and jsonify
            # cannot serialize.
            cross_score = float(raw_cross_score)
            print(f"\n [Re-Ranking] Item ID: {item.get('_id')}")
            print(f" - Cross-Encoder Score: {cross_score:.4f}")

            if has_query_image and item.get('shape_features'):
                image_score = _image_similarity(query_item, item, FEATURE_WEIGHTS)
                final_score = float(0.4 * image_score + 0.6 * cross_score)
                print(f" - Image Score: {image_score:.4f} | Final Re-ranked Score: {final_score:.4f}")
            else:
                # No image on one side: fall back to the text score alone.
                final_score = cross_score

            if final_score >= FINAL_SCORE_THRESHOLD:
                print(f" - ✅ ACCEPTED (Score >= {FINAL_SCORE_THRESHOLD})")
                final_results.append({
                    "_id": item.get('_id'),
                    "score": round(final_score, 4),
                    "objectName": item.get("objectName"),
                    "objectDescription": item.get("objectDescription"),
                    "objectImage": item.get("objectImage"),
                })
            else:
                print(f" - ❌ REJECTED (Score < {FINAL_SCORE_THRESHOLD})")

        final_results.sort(key=lambda match: match["score"], reverse=True)
        print(f"\n✅ Search complete. Found {len(final_results)} final matches after re-ranking.")
        print("="*50)
        return jsonify({"matches": final_results}), 200
    except Exception as e:
        # Top-level request boundary: log the full traceback and report the
        # failure to the caller as a JSON 500.
        print(f"❌ Error in /compare: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500