import traceback
import numpy as np
import cv2
from flask import request, jsonify

# Shared Flask app, pre-loaded models, scoring constants, and helper logic
from pipeline import app, models, logic, FEATURE_WEIGHTS, FINAL_SCORE_THRESHOLD

# Number of bi-encoder candidates passed on to the cross-encoder re-ranking stage
TOP_N_CANDIDATES = 20

@app.route('/process', methods=['POST'])
def process_item():
    """Extract a canonical label, a text embedding and (optionally) visual features for one item."""
    print("\n" + "="*50)
    print("➡ [Request] Received new request to /process")
    try:
        data = request.get_json()
        if not data:
            return jsonify({"error": "Invalid JSON payload"}), 400

        object_name = data.get('objectName')
        description = data.get('objectDescription')
        image_url = data.get('objectImage')

        if not all([object_name, description]):
            return jsonify({"error": "objectName and objectDescription are required."}), 400

        canonical_label = logic.get_canonical_label(object_name)
        text_embedding = logic.get_text_embedding(description, models)

        response_data = {
            "canonicalLabel": canonical_label,
            "text_embedding": text_embedding,
        }

        if image_url:
            print("--- Image URL provided, processing visual features... ---")
            image = logic.download_image_from_url(image_url)
            object_crop = logic.detect_and_crop(image, canonical_label, models)
            visual_features = logic.extract_features(object_crop)
            response_data.update(visual_features)
        else:
            print("--- No image URL provided, skipping visual feature extraction. ---")

        print("βœ… Successfully processed item.")
        print("="*50)
        return jsonify(response_data), 200

    except Exception as e:
        print(f"❌ Error in /process: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

@app.route('/compare', methods=['POST'])
def compare_items():
    """Match a query item against a list of candidates with a two-stage retrieve-and-re-rank pipeline."""
    print("\n" + "="*50)
    print("➡ [Request] Received new request to /compare")
    try:
        data = request.get_json()
        if not data:
            return jsonify({"error": "Invalid JSON payload"}), 400

        query_item = data.get('queryItem')
        search_list = data.get('searchList')

        if not all([query_item, search_list]):
            return jsonify({"error": "queryItem and searchList are required."}), 400

        # === STAGE 1: FAST RETRIEVAL (using Bi-Encoder) ===
        print(f"--- Stage 1: Retrieving top candidates from {len(search_list)} items... ---")

        initial_candidates = []
        query_text_emb = np.array(query_item['text_embedding'])

        for item in search_list:
            text_emb_found = np.array(item['text_embedding'])
            text_score = logic.cosine_similarity(query_text_emb, text_emb_found)
            initial_candidates.append({"item": item, "initial_score": text_score})

        # Sort by the initial score and keep the best ones
        initial_candidates.sort(key=lambda x: x["initial_score"], reverse=True)
        top_candidates = initial_candidates[:TOP_N_CANDIDATES]
        print(f"--- Found {len(top_candidates)} candidates for re-ranking. ---")

        # === STAGE 2: ACCURATE RE-RANKING (using Cross-Encoder) ===
        if not top_candidates:
            print("βœ… No potential matches found in Stage 1.")
            return jsonify({"matches": []}), 200

        print(f"\n--- Stage 2: Re-ranking top {len(top_candidates)} candidates... ---")
        query_description = query_item['objectDescription']

        # Create pairs of [query, candidate_description] for the cross-encoder
        rerank_pairs = [(query_description, cand['item']['objectDescription']) for cand in top_candidates]

        # Get new, highly accurate scores from the cross-encoder
        cross_encoder_scores = models['cross_encoder'].predict(rerank_pairs)

        # Now, build the final results with the new scores
        final_results = []
        for i, candidate_data in enumerate(top_candidates):
            item = candidate_data['item']
            cross_score = float(cross_encoder_scores[i])  # cast to a plain float so the score serializes cleanly in the JSON response
            print(f"\n [Re-Ranking] Item ID: {item.get('_id')}")
            print(f" - Cross-Encoder Score: {cross_score:.4f}")

            # Blend in a visual similarity score when both the query and the candidate carry image features
            has_query_image = 'shape_features' in query_item and query_item['shape_features']
            has_item_image = 'shape_features' in item and item['shape_features']

            if has_query_image and has_item_image:
                query_shape = np.array(query_item['shape_features'])
                query_color = np.array(query_item['color_features']).astype("float32")
                query_texture = np.array(query_item['texture_features']).astype("float32")
                found_shape = np.array(item['shape_features'])
                found_color = np.array(item['color_features']).astype("float32")
                found_texture = np.array(item['texture_features']).astype("float32")

                # Shape: matchShapes returns a distance (0 = identical), so convert it to a 0-1 similarity
                shape_dist = cv2.matchShapes(query_shape, found_shape, cv2.CONTOURS_MATCH_I1, 0.0)
                shape_score = 1.0 / (1.0 + shape_dist)
                # Color and texture: histogram correlation (1.0 = identical)
                color_score = cv2.compareHist(query_color, found_color, cv2.HISTCMP_CORREL)
                texture_score = cv2.compareHist(query_texture, found_texture, cv2.HISTCMP_CORREL)

                # Weighted blend of the visual scores, then stretched into a more usable range
                raw_image_score = (FEATURE_WEIGHTS["shape"] * shape_score +
                                   FEATURE_WEIGHTS["color"] * color_score +
                                   FEATURE_WEIGHTS["texture"] * texture_score)
                image_score = logic.stretch_image_score(raw_image_score)

                # Final score: 40% visual similarity, 60% cross-encoder text similarity
                final_score = 0.4 * image_score + 0.6 * cross_score
                print(f" - Image Score: {image_score:.4f} | Final Re-ranked Score: {final_score:.4f}")
            else:
                final_score = cross_score

            if final_score >= FINAL_SCORE_THRESHOLD:
                print(f" - ✅ ACCEPTED (Score >= {FINAL_SCORE_THRESHOLD})")
                final_results.append({
                    "_id": item.get('_id'),
                    "score": round(final_score, 4),
                    "objectName": item.get("objectName"),
                    "objectDescription": item.get("objectDescription"),
                    "objectImage": item.get("objectImage"),
                })
            else:
                print(f" - ❌ REJECTED (Score < {FINAL_SCORE_THRESHOLD})")

        final_results.sort(key=lambda x: x["score"], reverse=True)
        print(f"\nβœ… Search complete. Found {len(final_results)} final matches after re-ranking.")
        print("="*50)
        return jsonify({"matches": final_results}), 200

    except Exception as e:
        print(f"❌ Error in /compare: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500