sohamnk commited on
Commit
97d00a5
Β·
verified Β·
1 Parent(s): 7a4305a
Files changed (1) hide show
  1. app.py +1 -296
app.py CHANGED
@@ -1,299 +1,4 @@
1
- import os
2
- import numpy as np
3
- import requests
4
- import cv2
5
- from skimage import feature
6
- from io import BytesIO
7
- import traceback
8
 
9
- from flask import Flask, request, jsonify
10
- from PIL import Image
11
-
12
- # import deep learning libraries
13
- import torch
14
- from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection, AutoTokenizer, AutoModel
15
- from segment_anything import SamPredictor, sam_model_registry
16
-
17
- app = Flask(__name__)
18
-
19
- # sum = 1
20
- FEATURE_WEIGHTS = {
21
- "shape": 0.4,
22
- "color": 0.5,
23
- "texture": 0.1
24
- }
25
-
26
- # threshold
27
- FINAL_SCORE_THRESHOLD = 0.5
28
-
29
-
30
- # load all models
31
- print("="*50)
32
- print("πŸš€ Initializing application and loading models...")
33
- device_name = os.environ.get("device", "cpu")
34
- device = torch.device('cuda' if 'cuda' in device_name and torch.cuda.is_available() else 'cpu')
35
- print(f"🧠 Using device: {device}")
36
-
37
- print("...Loading Grounding DINO model...")
38
- gnd_model_id = "IDEA-Research/grounding-dino-tiny"
39
- processor_gnd = AutoProcessor.from_pretrained(gnd_model_id)
40
- model_gnd = AutoModelForZeroShotObjectDetection.from_pretrained(gnd_model_id).to(device)
41
-
42
- print("...Loading Segment Anything (SAM) model...")
43
- sam_checkpoint = "sam_vit_b_01ec64.pth"
44
- sam_model = sam_model_registry["vit_b"](checkpoint=sam_checkpoint).to(device)
45
- predictor = SamPredictor(sam_model)
46
-
47
- print("...Loading BGE model for text embeddings...")
48
- bge_model_id = "BAAI/bge-small-en-v1.5"
49
- tokenizer_text = AutoTokenizer.from_pretrained(bge_model_id)
50
- model_text = AutoModel.from_pretrained(bge_model_id).to(device)
51
- print("βœ… All models loaded successfully.")
52
- print("="*50)
53
-
54
-
55
- # helper functions
56
-
57
- def get_canonical_label(object_name_phrase: str) -> str:
58
- print(f"\n [Label] Extracting label for: '{object_name_phrase}'")
59
- label = object_name_phrase.strip().lower().split()[-1]
60
- label = ''.join(filter(str.isalpha, label))
61
- print(f" [Label] βœ… Extracted label: '{label}'")
62
- return label if label else "unknown"
63
-
64
- def download_image_from_url(image_url: str) -> Image.Image:
65
- print(f" [Download] Downloading image from: {image_url[:80]}...")
66
- response = requests.get(image_url)
67
- response.raise_for_status()
68
- image = Image.open(BytesIO(response.content))
69
- image_rgb = image.convert("RGB")
70
- print(" [Download] βœ… Image downloaded and standardized to RGB.")
71
- return image_rgb
72
-
73
- def detect_and_crop(image: Image.Image, object_name: str) -> Image.Image:
74
- print(f"\n [Detect & Crop] Starting detection for object: '{object_name}'")
75
- image_np = np.array(image.convert("RGB"))
76
- height, width = image_np.shape[:2]
77
- prompt = [[f"a {object_name}"]]
78
- inputs = processor_gnd(images=image, text=prompt, return_tensors="pt").to(device)
79
- with torch.no_grad():
80
- outputs = model_gnd(**inputs)
81
- results = processor_gnd.post_process_grounded_object_detection(
82
- outputs, inputs.input_ids, box_threshold=0.4, text_threshold=0.3, target_sizes=[(height, width)]
83
- )
84
- if not results or len(results[0]['boxes']) == 0:
85
- print(" [Detect & Crop] ⚠ Warning: Grounding DINO did not detect the object. Using full image.")
86
- return image
87
- result = results[0]
88
- scores = result['scores']
89
- max_idx = int(torch.argmax(scores))
90
- box = result['boxes'][max_idx].cpu().numpy().astype(int)
91
- print(f" [Detect & Crop] βœ… Object detected with confidence: {scores[max_idx]:.2f}, Box: {box}")
92
- x1, y1, x2, y2 = box
93
-
94
- predictor.set_image(image_np)
95
- box_prompt = np.array([[x1, y1, x2, y2]])
96
- masks, _, _ = predictor.predict(box=box_prompt, multimask_output=False)
97
- mask = masks[0]
98
-
99
- mask_bool = mask > 0
100
- cropped_img_rgba = np.zeros((height, width, 4), dtype=np.uint8)
101
- cropped_img_rgba[:, :, :3] = image_np
102
- cropped_img_rgba[:, :, 3] = mask_bool * 255
103
-
104
- cropped_img_rgba = cropped_img_rgba[y1:y2, x1:x2]
105
-
106
- object_image = Image.fromarray(cropped_img_rgba, 'RGBA')
107
- return object_image
108
-
109
- def extract_features(segmented_image: Image.Image) -> dict:
110
- image_rgba = np.array(segmented_image)
111
- if image_rgba.shape[2] != 4:
112
- raise ValueError("Segmented image must be RGBA")
113
-
114
- b, g, r, a = cv2.split(image_rgba)
115
- image_rgb = cv2.merge((b, g, r))
116
- mask = a
117
-
118
- gray = cv2.cvtColor(image_rgb, cv2.COLOR_BGR2GRAY)
119
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
120
- hu_moments = cv2.HuMoments(cv2.moments(contours[0])).flatten() if contours else np.zeros(7)
121
-
122
- color_hist = cv2.calcHist([image_rgb], [0, 1, 2], mask, [8, 8, 8], [0, 256, 0, 256, 0, 256])
123
- cv2.normalize(color_hist, color_hist)
124
- color_hist = color_hist.flatten()
125
-
126
- gray_masked = cv2.bitwise_and(gray, gray, mask=mask)
127
- lbp = feature.local_binary_pattern(gray_masked, P=24, R=3, method="uniform")
128
- (texture_hist, _) = np.histogram(lbp.ravel(), bins=np.arange(0, 27), range=(0, 26))
129
- texture_hist = texture_hist.astype("float32")
130
- texture_hist /= (texture_hist.sum() + 1e-6)
131
-
132
- return {
133
- "shape_features": hu_moments.tolist(),
134
- "color_features": color_hist.tolist(),
135
- "texture_features": texture_hist.tolist()
136
- }
137
-
138
- def get_text_embedding(text: str) -> list:
139
- print(f" [Embedding] Generating text embedding for: '{text[:50]}...'")
140
- text_with_instruction = f"Represent this sentence for searching relevant passages: {text}"
141
- inputs = tokenizer_text(text_with_instruction, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
142
- with torch.no_grad():
143
- outputs = model_text(**inputs)
144
- embedding = outputs.last_hidden_state[:, 0, :]
145
- embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
146
- print(" [Embedding] βœ… Text embedding generated.")
147
- return embedding.cpu().numpy()[0].tolist()
148
-
149
- def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
150
- return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))
151
-
152
- # API endpoints
153
-
154
- @app.route('/process', methods=['POST'])
155
- def process_item():
156
- """
157
- Receives item details, processes them, and returns all computed features.
158
- This is called when a new item is created in the Node.js backend.
159
- """
160
- print("\n" + "="*50)
161
- print("➑ [Request] Received new request to /process")
162
- try:
163
- data = request.get_json()
164
- if not data:
165
- return jsonify({"error": "Invalid JSON payload"}), 400
166
-
167
- object_name = data.get('objectName')
168
- description = data.get('objectDescription')
169
- image_url = data.get('objectImage') # This can now be null
170
-
171
- if not all([object_name, description]):
172
- return jsonify({"error": "objectName and objectDescription are required."}), 400
173
-
174
- # process text based features
175
- canonical_label = get_canonical_label(object_name)
176
- text_embedding = get_text_embedding(description)
177
-
178
- response_data = {
179
- "canonicalLabel": canonical_label,
180
- "text_embedding": text_embedding,
181
- }
182
-
183
- # process visual features ONLY if an image_url is provided
184
- if image_url:
185
- print("--- Image URL provided, processing visual features... ---")
186
- image = download_image_from_url(image_url)
187
- object_crop = detect_and_crop(image, canonical_label)
188
- visual_features = extract_features(object_crop)
189
- # Add visual features to the response
190
- response_data.update(visual_features)
191
- else:
192
- print("--- No image URL provided, skipping visual feature extraction. ---")
193
-
194
- print("βœ… Successfully processed item.")
195
- print("="*50)
196
- return jsonify(response_data), 200
197
-
198
- except Exception as e:
199
- print(f"❌ Error in /process: {e}")
200
- traceback.print_exc()
201
- return jsonify({"error": str(e)}), 500
202
-
203
- def stretch_image_score(score):
204
- if score < 0.4 or score == 1.0:
205
- return score
206
- # increase confidence
207
- return 0.7 + (score - 0.4) * (0.99 - 0.7) / (1.0 - 0.4)
208
-
209
- @app.route('/compare', methods=['POST'])
210
- def compare_items():
211
- print("\n" + "="*50)
212
- print("➑ [Request] Received new request to /compare")
213
- try:
214
- data = request.get_json()
215
- if not data:
216
- return jsonify({"error": "Invalid JSON payload"}), 400
217
-
218
- query_item = data.get('queryItem')
219
- search_list = data.get('searchList')
220
-
221
- if not all([query_item, search_list]):
222
- return jsonify({"error": "queryItem and searchList are required."}), 400
223
-
224
- query_text_emb = np.array(query_item['text_embedding'])
225
- results = []
226
- print(f"--- Comparing 1 query item against {len(search_list)} items ---")
227
-
228
- for item in search_list:
229
- item_id = item.get('_id')
230
- print(f"\n [Checking] Item ID: {item_id}")
231
- try:
232
- # Text comparison is always done
233
- text_emb_found = np.array(item['text_embedding'])
234
- text_score = cosine_similarity(query_text_emb, text_emb_found)
235
- print(f" - Text Score: {text_score:.4f}")
236
-
237
- # --- NEW: Check if BOTH items have visual features ---
238
- has_query_image = 'shape_features' in query_item and query_item['shape_features']
239
- has_item_image = 'shape_features' in item and item['shape_features']
240
-
241
- if has_query_image and has_item_image:
242
- print(" - Both items have images. Performing visual comparison.")
243
- # If both have images, proceed with full comparison
244
- query_shape_feat = np.array(query_item['shape_features'])
245
- query_color_feat = np.array(query_item['color_features']).astype("float32")
246
- query_texture_feat = np.array(query_item['texture_features']).astype("float32")
247
-
248
- found_shape = np.array(item['shape_features'])
249
- found_color = np.array(item['color_features']).astype("float32")
250
- found_texture = np.array(item['texture_features']).astype("float32")
251
-
252
- shape_dist = cv2.matchShapes(query_shape_feat, found_shape, cv2.CONTOURS_MATCH_I1, 0.0)
253
- shape_score = 1.0 / (1.0 + shape_dist)
254
- color_score = cv2.compareHist(query_color_feat, found_color, cv2.HISTCMP_CORREL)
255
- texture_score = cv2.compareHist(query_texture_feat, found_texture, cv2.HISTCMP_CORREL)
256
-
257
- raw_image_score = (FEATURE_WEIGHTS["shape"] * shape_score +
258
- FEATURE_WEIGHTS["color"] * color_score +
259
- FEATURE_WEIGHTS["texture"] * texture_score)
260
-
261
- image_score = stretch_image_score(raw_image_score)
262
-
263
- # Weighted average of image and text scores
264
- final_score = 0.4 * image_score + 0.6 * text_score
265
- print(f" - Image Score: {image_score:.4f} | Final Score: {final_score:.4f}")
266
-
267
- else:
268
- # If one or both items lack an image, the final score is JUST the text score
269
- print(" - One or both items missing image. Using text score only.")
270
- final_score = text_score
271
-
272
- # Check if the final score meets the threshold
273
- if final_score >= FINAL_SCORE_THRESHOLD:
274
- print(f" - βœ… ACCEPTED (Score >= {FINAL_SCORE_THRESHOLD})")
275
- results.append({
276
- "_id": item_id,
277
- "score": round(final_score, 4),
278
- "objectName": item.get("objectName"),
279
- "objectDescription": item.get("objectDescription"),
280
- "objectImage": item.get("objectImage"),
281
- })
282
- else:
283
- print(f" - ❌ REJECTED (Score < {FINAL_SCORE_THRESHOLD})")
284
-
285
- except Exception as e:
286
- print(f" [Skipping] Item {item_id} due to processing error: {e}")
287
- continue
288
-
289
- results.sort(key=lambda x: x["score"], reverse=True)
290
- print(f"\nβœ… Search complete. Found {len(results)} potential matches.")
291
- print("="*50)
292
- return jsonify({"matches": results}), 200
293
-
294
- except Exception as e:
295
- print(f"❌ Error in /compare: {e}")
296
- traceback.print_exc()
297
- return jsonify({"error": str(e)}), 500
298
  if __name__ == '__main__':
299
  app.run(host='0.0.0.0', port=7860)
 
1
+ from pipeline import app
 
 
 
 
 
 
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  if __name__ == '__main__':
4
  app.run(host='0.0.0.0', port=7860)