File size: 7,506 Bytes
7a4305a 754ba00 7a4305a 754ba00 7a4305a 754ba00 7a4305a 754ba00 7a4305a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import traceback
import numpy as np
import cv2
from flask import request, jsonify
# Import app, models, and logic functions
from pipeline import app, models, logic
@app.route('/process', methods=['POST'])
def process_item():
    """Handle ``POST /process``: build the searchable features for one item.

    Reads ``objectName``, ``objectDescription`` and the optional
    ``objectImage`` URL from the JSON body. The response always carries
    ``canonicalLabel`` and ``text_embedding``; when an image URL is given,
    whatever mapping ``logic.extract_features`` returns is merged in as well.

    Returns:
        A ``(response, status)`` tuple:
        * 200 with the feature payload on success,
        * 400 when the body is not a JSON object or a required field is
          missing/empty,
        * 500 with ``{"error": str(exc)}`` when any processing step raises.
    """
    print("\n" + "="*50)
    print("⚡ [Request] Received new request to /process")
    try:
        # silent=True turns malformed or non-JSON bodies into None so they are
        # answered with the 400 below instead of raising into the generic
        # 500 handler at the bottom of this function.
        data = request.get_json(silent=True)
        if not data or not isinstance(data, dict):
            return jsonify({"error": "Invalid JSON payload"}), 400

        object_name = data.get('objectName')
        description = data.get('objectDescription')
        image_url = data.get('objectImage')

        if not all([object_name, description]):
            return jsonify({"error": "objectName and objectDescription are required."}), 400

        canonical_label = logic.get_canonical_label(object_name)
        text_embedding = logic.get_text_embedding(description, models)

        response_data = {
            "canonicalLabel": canonical_label,
            "text_embedding": text_embedding,
        }

        if image_url:
            print("--- Image URL provided, processing visual features... ---")
            image = logic.download_image_from_url(image_url)
            object_crop = logic.detect_and_crop(image, canonical_label, models)
            visual_features = logic.extract_features(object_crop)
            response_data.update(visual_features)
        else:
            print("--- No image URL provided, skipping visual feature extraction. ---")

        print("✅ Successfully processed item.")
        print("="*50)
        return jsonify(response_data), 200

    except Exception as e:
        # Top-level request boundary: log the traceback and report the failure
        # to the client rather than letting the exception escape the route.
        print(f"❌ Error in /process: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
# Number of Stage 1 (text-embedding) candidates handed to the cross-encoder
# for re-ranking in Stage 2.
TOP_N_CANDIDATES = 20


def _image_similarity(query_item, item):
    """Return the stretched, weighted shape/color/texture score for two items.

    Both mappings are indexed by ``shape_features``, ``color_features`` and
    ``texture_features``. The three partial scores are combined with
    ``FEATURE_WEIGHTS`` and passed through ``logic.stretch_image_score``.
    """
    # Kept as a function-scope import, as in the original code
    # (presumably to avoid a circular import with `pipeline` -- TODO confirm).
    from pipeline import FEATURE_WEIGHTS

    query_shape = np.array(query_item['shape_features'])
    query_color = np.array(query_item['color_features']).astype("float32")
    query_texture = np.array(query_item['texture_features']).astype("float32")

    found_shape = np.array(item['shape_features'])
    found_color = np.array(item['color_features']).astype("float32")
    found_texture = np.array(item['texture_features']).astype("float32")

    # matchShapes yields a distance (0 == identical); map it to (0, 1].
    shape_dist = cv2.matchShapes(query_shape, found_shape, cv2.CONTOURS_MATCH_I1, 0.0)
    shape_score = 1.0 / (1.0 + shape_dist)
    color_score = cv2.compareHist(query_color, found_color, cv2.HISTCMP_CORREL)
    texture_score = cv2.compareHist(query_texture, found_texture, cv2.HISTCMP_CORREL)

    raw_image_score = (FEATURE_WEIGHTS["shape"] * shape_score +
                       FEATURE_WEIGHTS["color"] * color_score +
                       FEATURE_WEIGHTS["texture"] * texture_score)
    return logic.stretch_image_score(raw_image_score)


@app.route('/compare', methods=['POST'])
def compare_items():
    """Handle ``POST /compare``: rank ``searchList`` items against ``queryItem``.

    Stage 1 scores every search item by cosine similarity of the stored
    ``text_embedding`` vectors and keeps the best ``TOP_N_CANDIDATES``.
    Stage 2 re-scores those candidates with ``models['cross_encoder']`` on the
    ``objectDescription`` pairs. When both the query and the candidate carry
    non-empty ``shape_features``, the final score is
    ``0.4 * image_score + 0.6 * cross_score``; otherwise it is the
    cross-encoder score alone. Candidates scoring below
    ``FINAL_SCORE_THRESHOLD`` are dropped.

    Returns:
        A ``(response, status)`` tuple:
        * 200 with ``{"matches": [...]}`` sorted by descending score,
        * 400 when the body is not a JSON object or ``queryItem`` /
          ``searchList`` is missing or empty,
        * 500 with ``{"error": str(exc)}`` when any step raises.
    """
    print("\n" + "="*50)
    print("⚡ [Request] Received new request to /compare")
    try:
        # silent=True turns malformed or non-JSON bodies into None so they are
        # answered with the 400 below instead of the generic 500 handler.
        data = request.get_json(silent=True)
        if not data or not isinstance(data, dict):
            return jsonify({"error": "Invalid JSON payload"}), 400

        query_item = data.get('queryItem')
        search_list = data.get('searchList')

        if not all([query_item, search_list]):
            return jsonify({"error": "queryItem and searchList are required."}), 400

        # === STAGE 1: FAST RETRIEVAL (using Bi-Encoder) ===
        print(f"--- Stage 1: Retrieving top candidates from {len(search_list)} items... ---")

        query_text_emb = np.array(query_item['text_embedding'])
        # Only the text similarity is used here; the full (image + text)
        # score is computed later, and only for the surviving candidates.
        initial_candidates = [
            {
                "item": item,
                "initial_score": logic.cosine_similarity(
                    query_text_emb, np.array(item['text_embedding'])
                ),
            }
            for item in search_list
        ]

        initial_candidates.sort(key=lambda x: x["initial_score"], reverse=True)
        top_candidates = initial_candidates[:TOP_N_CANDIDATES]
        print(f"--- Found {len(top_candidates)} candidates for re-ranking. ---")

        # === STAGE 2: ACCURATE RE-RANKING (using Cross-Encoder) ===
        if not top_candidates:
            print("✅ No potential matches found in Stage 1.")
            return jsonify({"matches": []}), 200

        print(f"\n--- Stage 2: Re-ranking top {len(top_candidates)} candidates... ---")

        # Function-scope import kept from the original code, but hoisted out
        # of the per-candidate loop since the value does not change.
        from pipeline import FINAL_SCORE_THRESHOLD

        query_description = query_item['objectDescription']
        # Pairs of (query, candidate_description) for the cross-encoder.
        rerank_pairs = [
            (query_description, cand['item']['objectDescription'])
            for cand in top_candidates
        ]
        cross_encoder_scores = models['cross_encoder'].predict(rerank_pairs)

        # Loop-invariant: whether the query itself has visual features.
        has_query_image = bool(query_item.get('shape_features'))

        final_results = []
        for candidate_data, cross_score in zip(top_candidates, cross_encoder_scores):
            item = candidate_data['item']

            print(f"\n [Re-Ranking] Item ID: {item.get('_id')}")
            print(f" - Cross-Encoder Score: {cross_score:.4f}")

            if has_query_image and item.get('shape_features'):
                image_score = _image_similarity(query_item, item)
                # The cross-encoder score supplies the text part of the blend.
                final_score = 0.4 * image_score + 0.6 * cross_score
                print(f" - Image Score: {image_score:.4f} | Final Re-ranked Score: {final_score:.4f}")
            else:
                # Without images on both sides, rely on the text score alone.
                final_score = cross_score

            # The model score is presumably a numpy scalar (e.g. float32),
            # which jsonify cannot serialise; convert to a built-in float.
            final_score = float(final_score)

            if final_score >= FINAL_SCORE_THRESHOLD:
                print(f" - ✅ ACCEPTED (Score >= {FINAL_SCORE_THRESHOLD})")
                final_results.append({
                    "_id": item.get('_id'),
                    "score": round(final_score, 4),
                    "objectName": item.get("objectName"),
                    "objectDescription": item.get("objectDescription"),
                    "objectImage": item.get("objectImage"),
                })
            else:
                print(f" - ❌ REJECTED (Score < {FINAL_SCORE_THRESHOLD})")

        final_results.sort(key=lambda x: x["score"], reverse=True)

        print(f"\n✅ Search complete. Found {len(final_results)} final matches after re-ranking.")
        print("="*50)
        return jsonify({"matches": final_results}), 200

    except Exception as e:
        # Top-level request boundary: log the traceback and report the failure
        # to the client rather than letting the exception escape the route.
        print(f"❌ Error in /compare: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500