fix
app.py
CHANGED
@@ -27,8 +27,8 @@ FEATURE_WEIGHTS = {
 # ---- Load Models ----
 print("="*50)
 print("🚀 Initializing application and loading models...")
-
-device = torch.device('cpu')
+device_name = os.environ.get("device", "cpu")
+device = torch.device('cuda' if 'cuda' in device_name and torch.cuda.is_available() else 'cpu')
 print(f"🔧 Using device: {device}")
 
 print("...Loading Grounding DINO model...")
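The new device setup only opts into CUDA when the environment asks for it and a GPU is actually visible to torch. A minimal standalone sketch of the same selection logic (the lowercase "device" environment variable name is taken from the diff; everything else is standard torch):

import os
import torch

# e.g. `device=cuda python app.py`; unset, or no visible GPU, keeps CPU
device_name = os.environ.get("device", "cpu")
use_cuda = 'cuda' in device_name and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')  # safe fallback
print(f"Using device: {device}")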
@@ -118,12 +118,7 @@ def extract_features(segmented_image: Image.Image) -> dict:
 
     color_hist = cv2.calcHist([image_rgb], [0, 1, 2], mask, [8, 8, 8], [0, 256, 0, 256, 0, 256])
     cv2.normalize(color_hist, color_hist)
-
-    # ------------------ THE FIX IS HERE ------------------
-    # The color_hist is multi-dimensional. We must flatten it to a 1D array
-    # before converting it to a list for the JSON response.
-    flat_color_hist = color_hist.flatten()
-    # ----------------------------------------------------
+    color_hist = color_hist.flatten()
 
     gray_masked = cv2.bitwise_and(gray, gray, mask=mask)
     lbp = feature.local_binary_pattern(gray_masked, P=24, R=3, method="uniform")
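The single flatten call replaces the earlier multi-line fix: cv2.calcHist over three channels with [8, 8, 8] bins returns an (8, 8, 8) array, and .tolist() on that would produce nested lists instead of the flat vector the JSON response needs. A small sketch of the shape change, using a zero array as a stand-in for the real histogram:

import numpy as np

color_hist = np.zeros((8, 8, 8), dtype=np.float32)  # stand-in for the calcHist output
flat = color_hist.flatten()                         # 8 * 8 * 8 = 512 bins in one row
print(color_hist.tolist()[0][0][:2])  # nested lists: [0.0, 0.0]
print(len(flat.tolist()))             # 512 plain floats, JSON-serialisable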
@@ -133,7 +128,7 @@ def extract_features(segmented_image: Image.Image) -> dict:
 
     return {
         "shape_features": hu_moments.tolist(),
-        "color_features": flat_color_hist.tolist(),
+        "color_features": color_hist.tolist(),
         "texture_features": texture_hist.tolist()
     }
 
@@ -168,11 +163,12 @@ def process_item():
 
         object_name = data.get('objectName')
         description = data.get('objectDescription')
-        image_url = data.get('objectImage')
+        image_url = data.get('objectImage') # This can now be null
 
         if not all([object_name, description]):
             return jsonify({"error": "objectName and objectDescription are required."}), 400
 
+        # --- Always process text-based features ---
        canonical_label = get_canonical_label(object_name)
        text_embedding = get_text_embedding(description)
 
@@ -181,11 +177,13 @@ def process_item():
             "text_embedding": text_embedding,
         }
 
+        # --- Process visual features ONLY if an image_url is provided ---
         if image_url:
             print("--- Image URL provided, processing visual features... ---")
             image = download_image_from_url(image_url)
             object_crop = detect_and_crop(image, canonical_label)
             visual_features = extract_features(object_crop)
+            # Add visual features to the response
             response_data.update(visual_features)
         else:
             print("--- No image URL provided, skipping visual feature extraction. ---")
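With objectImage now optional, a text-only request returns just the canonical label and text embedding, while an image-bearing request also carries the visual features. A hypothetical client call; the /process path, host, and field values are assumptions, since the route decorator sits outside this diff:

import requests

payload = {
    "objectName": "backpack",                          # hypothetical example values
    "objectDescription": "black backpack, red zipper",
    # "objectImage": "https://example.com/item.jpg",   # optional since this change
}
resp = requests.post("http://localhost:7860/process", json=payload, timeout=60)
print(resp.status_code, resp.json())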
@@ -215,12 +213,9 @@ def compare_items():
             return jsonify({"error": "queryItem and searchList are required."}), 400
 
         query_text_emb = np.array(query_item['text_embedding'])
-
-
-
-        query_shape_feat = np.array(query_item['shape_features'])
-        query_color_feat = np.array(query_item['color_features']).astype("float32")
-        query_texture_feat = np.array(query_item['texture_features']).astype("float32")
+        query_shape_feat = np.array(query_item['shape_features'])
+        query_color_feat = np.array(query_item['color_features']).astype("float32")
+        query_texture_feat = np.array(query_item['texture_features']).astype("float32")
 
         results = []
         print(f"--- Comparing 1 query item against {len(search_list)} items ---")
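The query features arrive as JSON lists, so they are rebuilt as numpy arrays, and the .astype("float32") casts matter: np.array on a list of Python floats defaults to float64, while cv2.compareHist expects CV_32F histograms. A tiny sketch with made-up three-bin histograms:

import numpy as np
import cv2

hist_a = np.array([0.2, 0.5, 0.3]).astype("float32")  # would be float64 without the cast
hist_b = np.array([0.1, 0.6, 0.3]).astype("float32")
print(cv2.compareHist(hist_a, hist_b, cv2.HISTCMP_CORREL))  # correlation score in [-1, 1]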
@@ -233,29 +228,24 @@ def compare_items():
             text_score = cosine_similarity(query_text_emb, text_emb_found)
             print(f" - Text Score: {text_score:.4f}")
 
-
-            shape_score = 1.0 / (1.0 + shape_dist)
-
-            color_score = cv2.compareHist(query_color_feat, found_color, cv2.HISTCMP_CORREL)
-            texture_score = cv2.compareHist(query_texture_feat, found_texture, cv2.HISTCMP_CORREL)
-
-            image_score = (FEATURE_WEIGHTS["shape"] * shape_score +
-                           FEATURE_WEIGHTS["color"] * color_score +
-                           FEATURE_WEIGHTS["texture"] * texture_score)
-
-            final_score = 0.4 * image_score + 0.6 * text_score
-            print(f" - Image Score: {image_score:.4f} | Final Score: {final_score:.4f}")
+            found_shape = np.array(item['shape_features'])
+            found_color = np.array(item['color_features']).astype("float32")
+            found_texture = np.array(item['texture_features']).astype("float32")
+
+            shape_dist = cv2.matchShapes(query_shape_feat, found_shape, cv2.CONTOURS_MATCH_I1, 0.0)
+            shape_score = 1.0 / (1.0 + shape_dist)
+
+            color_score = cv2.compareHist(query_color_feat, found_color, cv2.HISTCMP_CORREL)
+            texture_score = cv2.compareHist(query_texture_feat, found_texture, cv2.HISTCMP_CORREL)
+
+            image_score = (FEATURE_WEIGHTS["shape"] * shape_score +
+                           FEATURE_WEIGHTS["color"] * color_score +
+                           FEATURE_WEIGHTS["texture"] * texture_score)
+
+            final_score = 0.4 * image_score + 0.6 * text_score
+            print(f" - Image Score: {image_score:.4f} | Final Score: {final_score:.4f}")
 
             results.append({
                 "_id": item_id,
                 "score": round(final_score, 4),
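The ranking blends the three visual similarities through FEATURE_WEIGHTS and then mixes the result 40/60 with the text score. A worked sketch of the arithmetic; the weight values below are assumptions for illustration (the real FEATURE_WEIGHTS entries are truncated out of this diff), and only the 0.4/0.6 split comes from the code:

# assumed example weights; the diff only shows that the dict exists
FEATURE_WEIGHTS = {"shape": 0.3, "color": 0.4, "texture": 0.3}

shape_score, color_score, texture_score, text_score = 0.8, 0.6, 0.7, 0.9

image_score = (FEATURE_WEIGHTS["shape"] * shape_score +
               FEATURE_WEIGHTS["color"] * color_score +
               FEATURE_WEIGHTS["texture"] * texture_score)  # 0.24 + 0.24 + 0.21 = 0.69
final_score = 0.4 * image_score + 0.6 * text_score          # 0.276 + 0.54 = 0.816
print(round(final_score, 4))  # 0.816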
@@ -279,4 +269,4 @@ def compare_items():
         return jsonify({"error": str(e)}), 500
 
 if __name__ == '__main__':
-    app.run(host='0.0.0.0', port=7860)
+    app.run(host='0.0.0.0', port=7860)