Soham Kolte committed
Commit fd33333 · 1 Parent(s): 3c63ff6

Initial commit of AI pipeline
Files changed (4):
  1. Dockerfile +24 -0
  2. app.py +271 -0
  3. requirements.txt +9 -0
  4. sam_vit_b_01ec64.pth +3 -0
Dockerfile ADDED
@@ -0,0 +1,24 @@
+ # Use an official Python runtime as a parent image
+ FROM python:3.10-slim
+
+ # Set the working directory in the container
+ WORKDIR /code
+
+ # Install system dependencies required by OpenCV
+ RUN apt-get update && apt-get install -y \
+     libgl1-mesa-glx \
+     libglib2.0-0 \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy the requirements file into the container
+ COPY ./requirements.txt /code/requirements.txt
+
+ # Install the Python packages
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ # Copy the rest of your application files into the container
+ COPY . /code/
+
+ # Tell Gunicorn to run your app on the port Hugging Face expects (7860)
+ # The --timeout flag prevents the server from crashing during long model-loading times.
+ CMD ["gunicorn", "--bind", "0.0.0.0:7860", "--timeout", "300", "app:app"]
app.py ADDED
@@ -0,0 +1,271 @@
+ import os
+ import numpy as np
+ import requests
+ import cv2
+ from skimage import feature
+ from io import BytesIO
+ import traceback
+
+ from flask import Flask, request, jsonify
+ from PIL import Image
+
+ # ---- Import deep learning libraries for models ----
+ import torch
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection, AutoTokenizer, AutoModel
+ from segment_anything import SamPredictor, sam_model_registry
+
+ # ---- Configuration ----
+ app = Flask(__name__)
+
+ # Weights for combining feature scores. They must sum to 1.0
+ FEATURE_WEIGHTS = {
+     "shape": 0.5,
+     "color": 0.25,
+     "texture": 0.25
+ }
+
+ # ---- Load Models ----
+ print("="*50)
+ print("🚀 Initializing application and loading models...")
+ device_name = os.environ.get("device", "cpu")
+ device = torch.device('cuda' if 'cuda' in device_name and torch.cuda.is_available() else 'cpu')
+ print(f"🧠 Using device: {device}")
+
+ print("...Loading Grounding DINO model...")
+ gnd_model_id = "IDEA-Research/grounding-dino-tiny"
+ processor_gnd = AutoProcessor.from_pretrained(gnd_model_id)
+ model_gnd = AutoModelForZeroShotObjectDetection.from_pretrained(gnd_model_id).to(device)
+
+ print("...Loading Segment Anything (SAM) model...")
+ sam_checkpoint = "sam_vit_b_01ec64.pth"
+ sam_model = sam_model_registry["vit_b"](checkpoint=sam_checkpoint).to(device)
+ predictor = SamPredictor(sam_model)
+
+ print("...Loading BGE model for text embeddings...")
+ bge_model_id = "BAAI/bge-small-en-v1.5"
+ tokenizer_text = AutoTokenizer.from_pretrained(bge_model_id)
+ model_text = AutoModel.from_pretrained(bge_model_id).to(device)
+ print("✅ All models loaded successfully.")
+ print("="*50)
+
+
+ # ---- Helper Functions ----
+
+ def get_canonical_label(object_name_phrase: str) -> str:
+     print(f"\n [Label] Extracting label for: '{object_name_phrase}'")
+     label = object_name_phrase.strip().lower().split()[-1]
+     label = ''.join(filter(str.isalpha, label))
+     print(f" [Label] ✅ Extracted label: '{label}'")
+     return label if label else "unknown"
+
+ def download_image_from_url(image_url: str) -> Image.Image:
+     print(f" [Download] Downloading image from: {image_url[:80]}...")
+     response = requests.get(image_url)
+     response.raise_for_status()
+     image = Image.open(BytesIO(response.content))
+     print(" [Download] ✅ Image downloaded successfully.")
+     return image
+
+ def detect_and_crop(image: Image.Image, object_name: str) -> Image.Image:
+     print(f"\n [Detect & Crop] Starting detection for object: '{object_name}'")
+     image_np = np.array(image.convert("RGB"))
+     height, width = image_np.shape[:2]
+     prompt = [[f"a {object_name}"]]
+     inputs = processor_gnd(images=image, text=prompt, return_tensors="pt").to(device)
+     with torch.no_grad():
+         outputs = model_gnd(**inputs)
+     results = processor_gnd.post_process_grounded_object_detection(
+         outputs, inputs.input_ids, box_threshold=0.4, text_threshold=0.3, target_sizes=[(height, width)]
+     )
+     if not results or len(results[0]['boxes']) == 0:
+         print(" [Detect & Crop] ⚠ Warning: Grounding DINO did not detect the object. Using full image.")
+         return image
+     result = results[0]
+     scores = result['scores']
+     max_idx = int(torch.argmax(scores))
+     box = result['boxes'][max_idx].cpu().numpy().astype(int)
+     print(f" [Detect & Crop] ✅ Object detected with confidence: {scores[max_idx]:.2f}, Box: {box}")
+     x1, y1, x2, y2 = box
+
+     predictor.set_image(image_np)
+     box_prompt = np.array([[x1, y1, x2, y2]])
+     masks, _, _ = predictor.predict(box=box_prompt, multimask_output=False)
+     mask = masks[0]
+
+     mask_bool = mask > 0
+     cropped_img_rgba = np.zeros((height, width, 4), dtype=np.uint8)
+     cropped_img_rgba[:, :, :3] = image_np
+     cropped_img_rgba[:, :, 3] = mask_bool * 255
+
+     cropped_img_rgba = cropped_img_rgba[y1:y2, x1:x2]
+
+     object_image = Image.fromarray(cropped_img_rgba, 'RGBA')
+     return object_image
+
+ def extract_features(segmented_image: Image.Image) -> dict:
+     image_rgba = np.array(segmented_image)
+     if image_rgba.shape[2] != 4:
+         raise ValueError("Segmented image must be RGBA")
+
+     r, g, b, a = cv2.split(image_rgba)  # PIL arrays are RGBA: channels come back as R, G, B, A
+     image_rgb = cv2.merge((r, g, b))
+     mask = a
+
+     gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
+     contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+     hu_moments = cv2.HuMoments(cv2.moments(contours[0])).flatten() if contours else np.zeros(7)
+
+     color_hist = cv2.calcHist([image_rgb], [0, 1, 2], mask, [8, 8, 8], [0, 256, 0, 256, 0, 256])
+     cv2.normalize(color_hist, color_hist)
+     color_hist = color_hist.flatten()
+
+     gray_masked = cv2.bitwise_and(gray, gray, mask=mask)
+     lbp = feature.local_binary_pattern(gray_masked, P=24, R=3, method="uniform")
+     (texture_hist, _) = np.histogram(lbp.ravel(), bins=np.arange(0, 27), range=(0, 26))
+     texture_hist = texture_hist.astype("float32")
+     texture_hist /= (texture_hist.sum() + 1e-6)
+
+     return {
+         "shape_features": hu_moments.tolist(),
+         "color_features": color_hist.tolist(),
+         "texture_features": texture_hist.tolist()
+     }
+
+ def get_text_embedding(text: str) -> list:
+     print(f" [Embedding] Generating text embedding for: '{text[:50]}...'")
+     text_with_instruction = f"Represent this sentence for searching relevant passages: {text}"
+     inputs = tokenizer_text(text_with_instruction, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
+     with torch.no_grad():
+         outputs = model_text(**inputs)
+     embedding = outputs.last_hidden_state[:, 0, :]
+     embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
+     print(" [Embedding] ✅ Text embedding generated.")
+     return embedding.cpu().numpy()[0].tolist()
+
+ def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
+     return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))
+
+ # ---- API Endpoints ----
+
+ @app.route('/process', methods=['POST'])
+ def process_item():
+     """
+     Receives item details, processes them, and returns all computed features.
+     This is called when a new item is created in the Node.js backend.
+     """
+     print("\n" + "="*50)
+     print("➡ [Request] Received new request to /process")
+     try:
+         data = request.get_json()
+         if not data:
+             return jsonify({"error": "Invalid JSON payload"}), 400
+
+         object_name = data.get('objectName')
+         description = data.get('objectDescription')
+         image_url = data.get('objectImage')  # This can now be null
+
+         if not all([object_name, description]):
+             return jsonify({"error": "objectName and objectDescription are required."}), 400
+
+         # --- Always process text-based features ---
+         canonical_label = get_canonical_label(object_name)
+         text_embedding = get_text_embedding(description)
+
+         response_data = {
+             "canonicalLabel": canonical_label,
+             "text_embedding": text_embedding,
+         }
+
+         # --- Process visual features ONLY if an image_url is provided ---
+         if image_url:
+             print("--- Image URL provided, processing visual features... ---")
+             image = download_image_from_url(image_url)
+             object_crop = detect_and_crop(image, canonical_label)
+             visual_features = extract_features(object_crop)
+             # Add visual features to the response
+             response_data.update(visual_features)
+         else:
+             print("--- No image URL provided, skipping visual feature extraction. ---")
+
+         print("✅ Successfully processed item.")
+         print("="*50)
+         return jsonify(response_data), 200
+
+     except Exception as e:
+         print(f"❌ Error in /process: {e}")
+         traceback.print_exc()
+         return jsonify({"error": str(e)}), 500
+
+ @app.route('/compare', methods=['POST'])
+ def compare_items():
+     print("\n" + "="*50)
+     print("➡ [Request] Received new request to /compare")
+     try:
+         data = request.get_json()
+         if not data:
+             return jsonify({"error": "Invalid JSON payload"}), 400
+
+         query_item = data.get('queryItem')
+         search_list = data.get('searchList')
+
+         if not all([query_item, search_list]):
+             return jsonify({"error": "queryItem and searchList are required."}), 400
+
+         query_text_emb = np.array(query_item['text_embedding'])
+         query_shape_feat = np.array(query_item['shape_features'])
+         query_color_feat = np.array(query_item['color_features']).astype("float32")
+         query_texture_feat = np.array(query_item['texture_features']).astype("float32")
+
+         results = []
+         print(f"--- Comparing 1 query item against {len(search_list)} items ---")
+
+         for item in search_list:
+             item_id = item.get('_id')
+             print(f"\n [Checking] Item ID: {item_id}")
+             try:
+                 text_emb_found = np.array(item['text_embedding'])
+                 text_score = cosine_similarity(query_text_emb, text_emb_found)
+                 print(f" - Text Score: {text_score:.4f}")
+
+                 found_shape = np.array(item['shape_features'])
+                 found_color = np.array(item['color_features']).astype("float32")
+                 found_texture = np.array(item['texture_features']).astype("float32")
+
+                 # cv2.matchShapes expects contours or images, not stored Hu-moment
+                 # vectors, so compare the Hu moments directly with an I1-style metric.
+                 eps = 1e-10
+                 log_query = -np.sign(query_shape_feat) * np.log10(np.abs(query_shape_feat) + eps)
+                 log_found = -np.sign(found_shape) * np.log10(np.abs(found_shape) + eps)
+                 shape_dist = float(np.sum(np.abs(1.0 / (log_query + eps) - 1.0 / (log_found + eps))))
+                 shape_score = 1.0 / (1.0 + shape_dist)
+
+                 color_score = cv2.compareHist(query_color_feat, found_color, cv2.HISTCMP_CORREL)
+                 texture_score = cv2.compareHist(query_texture_feat, found_texture, cv2.HISTCMP_CORREL)
+
+                 image_score = (FEATURE_WEIGHTS["shape"] * shape_score +
+                                FEATURE_WEIGHTS["color"] * color_score +
+                                FEATURE_WEIGHTS["texture"] * texture_score)
+
+                 final_score = 0.6 * image_score + 0.4 * text_score
+
+                 print(f" - Image Score: {image_score:.4f} | Final Score: {final_score:.4f}")
+
+                 results.append({
+                     "_id": item_id,
+                     "score": round(final_score, 4),
+                     "objectName": item.get("objectName"),
+                     "objectDescription": item.get("objectDescription"),
+                     "objectImage": item.get("objectImage"),
+                 })
+
+             except Exception as e:
+                 print(f" [Skipping] Item {item_id} due to error: {e}")
+                 continue
+
+         results.sort(key=lambda x: x["score"], reverse=True)
+         print(f"\n✅ Search complete. Found {len(results)} potential matches.")
+         print("="*50)
+         return jsonify({"matches": results}), 200
+
+     except Exception as e:
+         print(f"❌ Error in /compare: {e}")
+         traceback.print_exc()
+         return jsonify({"error": str(e)}), 500
+
+ if __name__ == '__main__':
+     app.run(host='0.0.0.0', port=7860)
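The two endpoints are designed to be called from a backend service: /process computes features for a new item, and /compare ranks a list of stored items against a query's features. The sketch below is illustrative only (not part of the commit); it assumes the app is reachable at localhost:7860, and the item fields, IDs, and image URL are made-up examples showing the expected request shapes:

```python
import requests

BASE = "http://localhost:7860"  # assumed local deployment

# 1) Register a new item: /process returns canonicalLabel, text_embedding,
#    and (when objectImage is set) shape/color/texture features.
item = {
    "objectName": "black leather wallet",
    "objectDescription": "A black bifold wallet, slightly worn at the corners.",
    "objectImage": "https://example.com/wallet.jpg",  # hypothetical URL
}
features = requests.post(f"{BASE}/process", json=item, timeout=300).json()

# 2) Search: /compare scores every searchList entry against the query item.
#    Each entry carries the same feature fields plus an _id.
query = {
    "queryItem": features,
    "searchList": [
        {"_id": "abc123",
         "objectName": "wallet",
         "objectDescription": "Found a black wallet near the library.",
         "objectImage": None,
         "text_embedding": features["text_embedding"],    # placeholder values
         "shape_features": features["shape_features"],
         "color_features": features["color_features"],
         "texture_features": features["texture_features"]},
    ],
}
matches = requests.post(f"{BASE}/compare", json=query, timeout=300).json()["matches"]
print(matches[0]["_id"], matches[0]["score"])  # best match first
```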
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ flask
+ torch
+ transformers
+ opencv-python-headless
+ scikit-image
+ Pillow
+ segment-anything-py
+ requests
+ gunicorn
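None of these requirements are pinned, so a rebuild can silently pick up breaking API changes (the transformers post-processing call used in app.py, for instance, has shifted its signature across releases). A small sketch, not part of the commit, for logging the versions an unpinned build actually resolved:

```python
# Print the installed version of each (unpinned) dependency, to aid
# debugging when a fresh build behaves differently from an old one.
import importlib.metadata as md

PACKAGES = ["flask", "torch", "transformers", "opencv-python-headless",
            "scikit-image", "Pillow", "segment-anything-py", "requests", "gunicorn"]

for pkg in PACKAGES:
    try:
        print(f"{pkg}=={md.version(pkg)}")
    except md.PackageNotFoundError:
        print(f"{pkg}: not installed")
```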
sam_vit_b_01ec64.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912
+ size 375042383
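The checkpoint is committed as a Git LFS pointer (the actual file is ~375 MB), so a clone without git-lfs yields only the three lines above and sam_model_registry will fail to load it. A fallback sketch, assuming the SAM ViT-B checkpoint URL published in the segment-anything README, that fetches the weights when only the pointer is present:

```python
import os
import requests

CKPT = "sam_vit_b_01ec64.pth"
# Checkpoint URL as published in the segment-anything README (assumption:
# still current at build time).
URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"

def ensure_sam_checkpoint() -> None:
    """Download the SAM weights if only an LFS pointer (or nothing) is on disk."""
    if os.path.exists(CKPT) and os.path.getsize(CKPT) > 1_000_000:
        return  # real weights present, not a ~130-byte pointer file
    with requests.get(URL, stream=True, timeout=60) as r:
        r.raise_for_status()
        with open(CKPT, "wb") as f:
            for chunk in r.iter_content(chunk_size=1 << 20):
                f.write(chunk)

ensure_sam_checkpoint()
```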