File size: 7,506 Bytes
7a4305a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
754ba00
 
 
 
7a4305a
 
 
 
 
 
 
 
 
 
 
 
 
 
754ba00
 
 
 
7a4305a
 
 
754ba00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a4305a
754ba00
 
 
 
 
 
7a4305a
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import traceback
import numpy as np
import cv2
from flask import request, jsonify

# Import app, models, and logic functions
from pipeline import app, models, logic

@app.route('/process', methods=['POST'])
def process_item():
    """Handle ``POST /process``: compute the features for a single item.

    The JSON body must be an object with ``objectName`` and
    ``objectDescription``; ``objectImage`` (an image URL) is optional.

    The response always carries ``canonicalLabel`` (from
    ``logic.get_canonical_label``) and ``text_embedding`` (from
    ``logic.get_text_embedding``). When an image URL is given, the image is
    downloaded, cropped via ``logic.detect_and_crop`` and whatever
    ``logic.extract_features`` returns is merged into the response.

    Returns:
        A ``(response, status)`` tuple: 200 with the feature payload, 400 when
        the body is not a non-empty JSON object or a required field is
        missing, 500 with the error text for any other failure.
    """
    print("\n" + "="*50)
    print("➑ [Request] Received new request to /process")
    try:
        # silent=True makes a malformed / non-JSON body come back as None, so
        # it is answered with the 400 below rather than raising into the
        # generic handler and being reported as a 500.
        data = request.get_json(silent=True)
        # A JSON list or scalar has no .get(); treat it as an invalid payload.
        if not isinstance(data, dict) or not data:
            return jsonify({"error": "Invalid JSON payload"}), 400

        object_name = data.get('objectName')
        description = data.get('objectDescription')
        image_url = data.get('objectImage')

        if not (object_name and description):
            return jsonify({"error": "objectName and objectDescription are required."}), 400

        canonical_label = logic.get_canonical_label(object_name)
        text_embedding = logic.get_text_embedding(description, models)

        response_data = {
            "canonicalLabel": canonical_label,
            "text_embedding": text_embedding,
        }

        if image_url:
            print("--- Image URL provided, processing visual features... ---")
            image = logic.download_image_from_url(image_url)
            object_crop = logic.detect_and_crop(image, canonical_label, models)
            visual_features = logic.extract_features(object_crop)
            # Visual features are merged as top-level keys of the response.
            response_data.update(visual_features)
        else:
            print("--- No image URL provided, skipping visual feature extraction. ---")

        print("βœ… Successfully processed item.")
        print("="*50)
        return jsonify(response_data), 200

    except Exception as e:
        # Top-level boundary of the endpoint: log the traceback and report
        # the failure to the client as a 500.
        print(f"❌ Error in /process: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

# Maximum number of Stage 1 candidates that /compare keeps for Stage 2 re-ranking.
TOP_N_CANDIDATES = 20

@app.route('/compare', methods=['POST'])
def compare_items():
    """Handle ``POST /compare``: rank ``searchList`` items against ``queryItem``.

    The JSON body must be an object with a ``queryItem`` and a non-empty
    ``searchList``. Ranking happens in two stages:

    1. Retrieval: every search item is scored with
       ``logic.cosine_similarity`` between its ``text_embedding`` and the
       query's, and the best ``TOP_N_CANDIDATES`` are kept.
    2. Re-ranking: ``models['cross_encoder'].predict`` scores each
       (query ``objectDescription``, candidate ``objectDescription``) pair.
       When both the query and the candidate have ``shape_features``, an
       image score built from shape, colour and texture comparisons is
       blended in as ``0.4 * image + 0.6 * cross``; otherwise the
       cross-encoder score is used on its own.

    Candidates whose final score reaches ``FINAL_SCORE_THRESHOLD`` are
    returned under ``matches``, sorted by descending score.

    Returns:
        A ``(response, status)`` tuple: 200 with ``{"matches": [...]}``, 400
        when the body is not a non-empty JSON object or a required field is
        missing, 500 with the error text for any other failure.
    """
    print("\n" + "="*50)
    print("➑ [Request] Received new request to /compare")
    try:
        # silent=True makes a malformed / non-JSON body come back as None, so
        # it is answered with the 400 below rather than being reported as a 500.
        data = request.get_json(silent=True)
        # A JSON list or scalar has no .get(); treat it as an invalid payload.
        if not isinstance(data, dict) or not data:
            return jsonify({"error": "Invalid JSON payload"}), 400

        query_item = data.get('queryItem')
        search_list = data.get('searchList')

        if not (query_item and search_list):
            return jsonify({"error": "queryItem and searchList are required."}), 400

        # === STAGE 1: FAST RETRIEVAL (using Bi-Encoder) ===
        print(f"--- Stage 1: Retrieving top candidates from {len(search_list)} items... ---")

        query_text_emb = np.array(query_item['text_embedding'])

        # Only the text similarity is used here; the full score is computed
        # later, and only for the candidates that survive this stage.
        initial_candidates = [
            {
                "item": item,
                "initial_score": logic.cosine_similarity(
                    query_text_emb, np.array(item['text_embedding'])
                ),
            }
            for item in search_list
        ]

        # Sort by the initial score and keep the best ones
        initial_candidates.sort(key=lambda x: x["initial_score"], reverse=True)
        top_candidates = initial_candidates[:TOP_N_CANDIDATES]
        print(f"--- Found {len(top_candidates)} candidates for re-ranking. ---")

        # === STAGE 2: ACCURATE RE-RANKING (using Cross-Encoder) ===
        if not top_candidates:
            print("βœ… No potential matches found in Stage 1.")
            return jsonify({"matches": []}), 200

        print(f"\n--- Stage 2: Re-ranking top {len(top_candidates)} candidates... ---")
        query_description = query_item['objectDescription']

        # Create pairs of (query, candidate_description) for the cross-encoder
        rerank_pairs = [
            (query_description, cand['item']['objectDescription'])
            for cand in top_candidates
        ]
        cross_encoder_scores = models['cross_encoder'].predict(rerank_pairs)

        # Function-scope import as before, but resolved once per request
        # instead of once per candidate.
        from pipeline import FEATURE_WEIGHTS, FINAL_SCORE_THRESHOLD

        # Whether the query carries image features does not depend on the
        # candidate, so it is evaluated once.
        has_query_image = bool(query_item.get('shape_features'))

        final_results = []
        for candidate_data, raw_cross_score in zip(top_candidates, cross_encoder_scores):
            item = candidate_data['item']
            # predict() may hand back NumPy scalars, which jsonify cannot
            # serialize; convert to a built-in float up front.
            cross_score = float(raw_cross_score)
            print(f"\n [Re-Ranking] Item ID: {item.get('_id')}")
            print(f" - Cross-Encoder Score: {cross_score:.4f}")

            has_item_image = bool(item.get('shape_features'))

            if has_query_image and has_item_image:
                query_shape = np.array(query_item['shape_features'])
                query_color = np.array(query_item['color_features']).astype("float32")
                query_texture = np.array(query_item['texture_features']).astype("float32")
                found_shape = np.array(item['shape_features'])
                found_color = np.array(item['color_features']).astype("float32")
                found_texture = np.array(item['texture_features']).astype("float32")

                # matchShapes yields a distance (0 = identical); map it to a
                # similarity in (0, 1] so it can be weighted with the others.
                shape_dist = cv2.matchShapes(query_shape, found_shape, cv2.CONTOURS_MATCH_I1, 0.0)
                shape_score = 1.0 / (1.0 + shape_dist)
                color_score = cv2.compareHist(query_color, found_color, cv2.HISTCMP_CORREL)
                texture_score = cv2.compareHist(query_texture, found_texture, cv2.HISTCMP_CORREL)
                raw_image_score = (FEATURE_WEIGHTS["shape"] * shape_score +
                                   FEATURE_WEIGHTS["color"] * color_score +
                                   FEATURE_WEIGHTS["texture"] * texture_score)
                image_score = logic.stretch_image_score(raw_image_score)
                # Blend image and text evidence; the text part is the cross-encoder score.
                final_score = float(0.4 * image_score + 0.6 * cross_score)
                print(f" - Image Score: {image_score:.4f} | Final Re-ranked Score: {final_score:.4f}")
            else:
                # Without image features on both sides, the cross-encoder
                # score is the final score.
                final_score = cross_score

            if final_score >= FINAL_SCORE_THRESHOLD:
                print(f" - βœ… ACCEPTED (Score >= {FINAL_SCORE_THRESHOLD})")
                final_results.append({
                    "_id": item.get('_id'),
                    "score": round(final_score, 4),
                    "objectName": item.get("objectName"),
                    "objectDescription": item.get("objectDescription"),
                    "objectImage": item.get("objectImage"),
                })
            else:
                print(f" - ❌ REJECTED (Score < {FINAL_SCORE_THRESHOLD})")

        final_results.sort(key=lambda x: x["score"], reverse=True)
        print(f"\nβœ… Search complete. Found {len(final_results)} final matches after re-ranking.")
        print("="*50)
        return jsonify({"matches": final_results}), 200

    except Exception as e:
        # Top-level boundary of the endpoint: log the traceback and report
        # the failure to the client as a 500.
        print(f"❌ Error in /compare: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500