import sys

sys.stdout.reconfigure(line_buffering=True)

import os
import traceback
from io import BytesIO

import cv2
import numpy as np
import requests
from flask import Flask, request, jsonify
from PIL import Image
from skimage import feature

# Deep learning libraries
import torch
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection, AutoTokenizer, AutoModel
from segment_anything import SamPredictor, sam_model_registry

app = Flask(__name__)

# Relative weights of the visual features (they should sum to 1).
FEATURE_WEIGHTS = {
    "shape": 0.4,
    "color": 0.5,
    "texture": 0.1
}

# Minimum combined score for an item to be returned as a match.
FINAL_SCORE_THRESHOLD = 0.5

# Load all models once at startup.
print("=" * 50)
print("šŸš€ Initializing application and loading models...")

device_name = os.environ.get("device", "cpu")
device = torch.device('cuda' if 'cuda' in device_name and torch.cuda.is_available() else 'cpu')
print(f"🧠 Using device: {device}")

print("...Loading Grounding DINO model...")
# Grounding DINO: open-vocabulary (text-prompted) object detection.
gnd_model_id = "IDEA-Research/grounding-dino-base"
processor_gnd = AutoProcessor.from_pretrained(gnd_model_id)
model_gnd = AutoModelForZeroShotObjectDetection.from_pretrained(gnd_model_id).to(device)

print("...Loading Segment Anything (SAM) model...")
sam_checkpoint = "sam_vit_b_01ec64.pth"
sam_model = sam_model_registry["vit_b"](checkpoint=sam_checkpoint).to(device)
predictor = SamPredictor(sam_model)

print("...Loading BGE model for text embeddings...")
# BGE: sentence embeddings used for text similarity between item descriptions.
bge_model_id = "BAAI/bge-large-en-v1.5"
tokenizer_text = AutoTokenizer.from_pretrained(bge_model_id)
model_text = AutoModel.from_pretrained(bge_model_id).to(device)

print("āœ… All models loaded successfully.")
print("=" * 50)


# Helper functions

def get_canonical_label(object_name_phrase: str) -> str:
    """Reduce a free-form object name to a single lowercase, letters-only label."""
    print(f"\n [Label] Extracting label for: '{object_name_phrase}'")
    words = object_name_phrase.strip().lower().split()
    label = words[-1] if words else ""
    label = ''.join(filter(str.isalpha, label))
    print(f" [Label] āœ… Extracted label: '{label}'")
    return label if label else "unknown"
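# Illustrative behaviour of the heuristic above (keep the last word, letters only):
#   get_canonical_label("a red Nike backpack") -> "backpack"
#   get_canonical_label("Blue iPhone 13")      -> "unknown"  (digits are stripped, leaving nothing)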

def download_image_from_url(image_url: str) -> Image.Image:
    print(f" [Download] Downloading image from: {image_url[:80]}...")
    response = requests.get(image_url)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    image_rgb = image.convert("RGB")
    print(" [Download] āœ… Image downloaded and standardized to RGB.")
    return image_rgb


def detect_and_crop(image: Image.Image, object_name: str) -> Image.Image:
    """Detect the named object with Grounding DINO, segment it with SAM, and
    return the masked crop as an RGBA image."""
    print(f"\n [Detect & Crop] Starting detection for object: '{object_name}'")
    image_np = np.array(image.convert("RGB"))
    height, width = image_np.shape[:2]

    prompt = [[f"a {object_name}"]]
    inputs = processor_gnd(images=image, text=prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model_gnd(**inputs)

    results = processor_gnd.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        threshold=0.4,
        text_threshold=0.3,
        target_sizes=[(height, width)]
    )

    if not results or len(results[0]['boxes']) == 0:
        print(" [Detect & Crop] ⚠ Warning: Grounding DINO did not detect the object. Using full image.")
        return image

    result = results[0]
    scores = result['scores']
    max_idx = int(torch.argmax(scores))
    box = result['boxes'][max_idx].cpu().numpy().astype(int)
    print(f" [Detect & Crop] āœ… Object detected with confidence: {scores[max_idx]:.2f}, Box: {box}")

    # Clamp the box to the image bounds so the crop slice below stays valid.
    x1, y1, x2, y2 = box
    x1, y1 = max(x1, 0), max(y1, 0)
    x2, y2 = min(x2, width), min(y2, height)

    predictor.set_image(image_np)
    box_prompt = np.array([[x1, y1, x2, y2]])
    masks, _, _ = predictor.predict(box=box_prompt, multimask_output=False)
    mask = masks[0]
    mask_bool = mask > 0

    # Build an RGBA image whose alpha channel is the SAM mask, then crop to the box.
    cropped_img_rgba = np.zeros((height, width, 4), dtype=np.uint8)
    cropped_img_rgba[:, :, :3] = image_np
    cropped_img_rgba[:, :, 3] = mask_bool * 255
    cropped_img_rgba = cropped_img_rgba[y1:y2, x1:x2]

    object_image = Image.fromarray(cropped_img_rgba, 'RGBA')
    return object_image


def extract_features(segmented_image: Image.Image) -> dict:
    """Compute shape (Hu moments), color (RGB histogram) and texture (LBP) features
    from an RGBA crop whose alpha channel is the object mask."""
    image_rgba = np.array(segmented_image)
    if image_rgba.shape[2] != 4:
        raise ValueError("Segmented image must be RGBA")

    # PIL arrays are in R, G, B, A channel order; the alpha channel is the mask.
    r, g, b, a = cv2.split(image_rgba)
    image_rgb = cv2.merge((r, g, b))
    mask = a
    gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)

    # Shape: Hu moments of the largest external contour of the mask.
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if contours:
        largest_contour = max(contours, key=cv2.contourArea)
        hu_moments = cv2.HuMoments(cv2.moments(largest_contour)).flatten()
    else:
        hu_moments = np.zeros(7)

    # Color: normalized 8x8x8 RGB histogram restricted to the mask.
    color_hist = cv2.calcHist([image_rgb], [0, 1, 2], mask, [8, 8, 8], [0, 256, 0, 256, 0, 256])
    cv2.normalize(color_hist, color_hist)
    color_hist = color_hist.flatten()

    # Texture: histogram of uniform local binary patterns inside the mask.
    gray_masked = cv2.bitwise_and(gray, gray, mask=mask)
    lbp = feature.local_binary_pattern(gray_masked, P=24, R=3, method="uniform")
    (texture_hist, _) = np.histogram(lbp.ravel(), bins=np.arange(0, 27), range=(0, 26))
    texture_hist = texture_hist.astype("float32")
    texture_hist /= (texture_hist.sum() + 1e-6)

    return {
        "shape_features": hu_moments.tolist(),
        "color_features": color_hist.tolist(),
        "texture_features": texture_hist.tolist()
    }


def get_text_embedding(text: str) -> list:
    print(f" [Embedding] Generating text embedding for: '{text[:50]}...'")
    # BGE retrieval instruction prefix used when embedding the description.
    text_with_instruction = f"Represent this sentence for searching relevant passages: {text}"
    inputs = tokenizer_text(text_with_instruction, return_tensors='pt', padding=True,
                            truncation=True, max_length=512).to(device)
    with torch.no_grad():
        outputs = model_text(**inputs)
    # Take the [CLS] token embedding and L2-normalize it.
    embedding = outputs.last_hidden_state[:, 0, :]
    embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
    print(" [Embedding] āœ… Text embedding generated.")
    return embedding.cpu().numpy()[0].tolist()


def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    # Since get_text_embedding() returns L2-normalized vectors, this is equivalent
    # to a plain dot product for those embeddings.
    return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))


# API endpoints

@app.route('/process', methods=['POST'])
def process_item():
    """
    Receives item details, processes them, and returns all computed features.
    This is called when a new item is created in the Node.js backend.
    """
    print("\n" + "=" * 50)
    print("āž” [Request] Received new request to /process")
    try:
        data = request.get_json()
        if not data:
            return jsonify({"error": "Invalid JSON payload"}), 400

        object_name = data.get('objectName')
        description = data.get('objectDescription')
        image_url = data.get('objectImage')  # May be null for text-only items.

        if not all([object_name, description]):
            return jsonify({"error": "objectName and objectDescription are required."}), 400

        # Text-based features are always computed.
        canonical_label = get_canonical_label(object_name)
        text_embedding = get_text_embedding(description)

        response_data = {
            "canonicalLabel": canonical_label,
            "text_embedding": text_embedding,
        }

        # Visual features are computed only when an image URL is provided.
        if image_url:
            print("--- Image URL provided, processing visual features... ---")
            image = download_image_from_url(image_url)
            object_crop = detect_and_crop(image, canonical_label)
            visual_features = extract_features(object_crop)
            # Add visual features to the response.
            response_data.update(visual_features)
        else:
            print("--- No image URL provided, skipping visual feature extraction. ---")

        print("āœ… Successfully processed item.")
        print("=" * 50)
        return jsonify(response_data), 200

    except Exception as e:
        print(f"āŒ Error in /process: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500


def stretch_image_score(score):
    """Boost mid-range image scores: values in [0.4, 1.0) are mapped linearly onto
    [0.7, 0.99); everything else is returned unchanged."""
    if score < 0.4 or score == 1.0:
        return score
    return 0.7 + (score - 0.4) * (0.99 - 0.7) / (1.0 - 0.4)

@app.route('/compare', methods=['POST'])
def compare_items():
    """Compare one query item against a list of candidate items and return the
    candidates whose combined text/visual score passes the threshold."""
    print("\n" + "=" * 50)
    print("āž” [Request] Received new request to /compare")
    try:
        data = request.get_json()
        if not data:
            return jsonify({"error": "Invalid JSON payload"}), 400

        query_item = data.get('queryItem')
        search_list = data.get('searchList')
        if not all([query_item, search_list]):
            return jsonify({"error": "queryItem and searchList are required."}), 400

        query_text_emb = np.array(query_item['text_embedding'])
        results = []
        print(f"--- Comparing 1 query item against {len(search_list)} items ---")

        for item in search_list:
            item_id = item.get('_id')
            print(f"\n [Checking] Item ID: {item_id}")
            try:
                # Text comparison is always done.
                text_emb_found = np.array(item['text_embedding'])
                text_score = cosine_similarity(query_text_emb, text_emb_found)
                print(f" - Text Score: {text_score:.4f}")

                # Visual comparison runs only when BOTH items carry visual features.
                has_query_image = 'shape_features' in query_item and query_item['shape_features']
                has_item_image = 'shape_features' in item and item['shape_features']

                if has_query_image and has_item_image:
                    print(" - Both items have images. Performing visual comparison.")
                    query_shape_feat = np.array(query_item['shape_features'])
                    query_color_feat = np.array(query_item['color_features']).astype("float32")
                    query_texture_feat = np.array(query_item['texture_features']).astype("float32")

                    found_shape = np.array(item['shape_features'])
                    found_color = np.array(item['color_features']).astype("float32")
                    found_texture = np.array(item['texture_features']).astype("float32")

                    # cv2.matchShapes expects contours or images rather than Hu-moment
                    # vectors, so the CONTOURS_MATCH_I1-style distance is computed
                    # directly from the stored Hu moments.
                    eps = 1e-12
                    log_query = np.sign(query_shape_feat) * np.log10(np.abs(query_shape_feat) + eps)
                    log_found = np.sign(found_shape) * np.log10(np.abs(found_shape) + eps)
                    shape_dist = float(np.sum(np.abs(1.0 / (log_query + eps) - 1.0 / (log_found + eps))))
                    shape_score = 1.0 / (1.0 + shape_dist)

                    color_score = cv2.compareHist(query_color_feat, found_color, cv2.HISTCMP_CORREL)
                    texture_score = cv2.compareHist(query_texture_feat, found_texture, cv2.HISTCMP_CORREL)

                    raw_image_score = (FEATURE_WEIGHTS["shape"] * shape_score +
                                       FEATURE_WEIGHTS["color"] * color_score +
                                       FEATURE_WEIGHTS["texture"] * texture_score)
                    print(f" - Raw Image Score: {raw_image_score:.4f}")
                    image_score = stretch_image_score(raw_image_score)

                    # Weighted average of image and text scores.
                    final_score = 0.4 * image_score + 0.6 * text_score
                    print(f" - Image Score: {image_score:.4f} | Final Score: {final_score:.4f}")
                else:
                    # If one or both items lack an image, the final score is just the text score.
                    print(" - One or both items missing image. Using text score only.")
                    final_score = text_score

                # Keep the item only if the final score meets the threshold.
                if final_score >= FINAL_SCORE_THRESHOLD:
                    print(f" - āœ… ACCEPTED (Score >= {FINAL_SCORE_THRESHOLD})")
                    results.append({
                        "_id": item_id,
                        "score": round(final_score, 4),
                        "objectName": item.get("objectName"),
                        "objectDescription": item.get("objectDescription"),
                        "objectImage": item.get("objectImage"),
                    })
                else:
                    print(f" - āŒ REJECTED (Score < {FINAL_SCORE_THRESHOLD})")

            except Exception as e:
                print(f" [Skipping] Item {item_id} due to processing error: {e}")
                continue

        results.sort(key=lambda x: x["score"], reverse=True)
        print(f"\nāœ… Search complete. Found {len(results)} potential matches.")
        print("=" * 50)
        return jsonify({"matches": results}), 200

    except Exception as e:
        print(f"āŒ Error in /compare: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
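
# ---------------------------------------------------------------------------
# Illustrative usage (a sketch only: the field values below are made up, and
# the model weights must be available locally before the server starts):
#
#   curl -X POST http://localhost:7860/process \
#        -H "Content-Type: application/json" \
#        -d '{"objectName": "black leather wallet",
#             "objectDescription": "Black bifold wallet lost near the library entrance.",
#             "objectImage": null}'
#
#   -> {"canonicalLabel": "wallet", "text_embedding": [...]}  (plus the
#      shape/color/texture feature lists when objectImage is a valid image URL)
#
#   /compare expects {"queryItem": {...}, "searchList": [{...}, ...]}, where every
#   item carries "text_embedding" (and optionally the visual feature lists), and
#   returns {"matches": [...]} sorted by descending score.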