sohamnk committed on
Commit 2aed1b0 · verified · 1 Parent(s): 88b35be

Update pipeline/logic.py

Files changed (1)
  1. pipeline/logic.py +50 -27
pipeline/logic.py CHANGED
@@ -3,11 +3,12 @@ import requests
 import cv2
 from skimage import feature
 from io import BytesIO
-from PIL import Image
+from PIL import Image, ImageFile
 import torch
-from PIL import ImageFile
+
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 
+
 def get_canonical_label(object_name_phrase: str) -> str:
     print(f"\n [Label] Extracting label for: '{object_name_phrase}'")
     label = object_name_phrase.strip().lower().split()[-1]
@@ -15,6 +16,7 @@ def get_canonical_label(object_name_phrase: str) -> str:
     print(f" [Label] ✅ Extracted label: '{label}'")
     return label if label else "unknown"
 
+
 def download_image_from_url(image_url: str) -> Image.Image:
     print(f" [Download] Downloading image from: {image_url[:80]}...")
     response = requests.get(image_url)
@@ -24,91 +26,112 @@ def download_image_from_url(image_url: str) -> Image.Image:
     print(" [Download] ✅ Image downloaded and standardized to RGB.")
     return image_rgb
 
+
 def detect_and_crop(image: Image.Image, object_name: str, models: dict) -> Image.Image:
     print(f"\n [Detect & Crop] Starting detection for object: '{object_name}'")
     image_np = np.array(image.convert("RGB"))
     height, width = image_np.shape[:2]
+
     prompt = [[f"a {object_name}"]]
-    inputs = models['processor_gnd'](images=image, text=prompt, return_tensors="pt").to(models['device'])
+    inputs = models['processor_gnd'](
+        images=image,
+        text=prompt,
+        return_tensors="pt"
+    ).to(models['device'])
 
     with torch.no_grad():
-        outputs = models['model_gnd'](
-            **inputs,
+        outputs = models['model_gnd'](**inputs)
+    results = models['processor_gnd'].post_process_grounded_object_detection(
+        outputs,
+        inputs.input_ids,
         box_threshold=0.4,
-        text_threshold=0.3
+        text_threshold=0.3,
+        target_sizes=[(height, width)]
     )
 
-    results = models['processor_gnd'].post_process_grounded_object_detection(
-        outputs=outputs,
-        input_ids=inputs.input_ids,
-        target_sizes=[(height, width)]
-    )
-
     if not results or len(results[0]['boxes']) == 0:
         print(" [Detect & Crop] ⚠ Warning: Grounding DINO did not detect the object. Using full image.")
         return image
+
     result = results[0]
     scores = result['scores']
     max_idx = int(torch.argmax(scores))
     box = result['boxes'][max_idx].cpu().numpy().astype(int)
     print(f" [Detect & Crop] ✅ Object detected with confidence: {scores[max_idx]:.2f}, Box: {box}")
+
     x1, y1, x2, y2 = box
     models['predictor'].set_image(image_np)
     box_prompt = np.array([[x1, y1, x2, y2]])
+
     masks, _, _ = models['predictor'].predict(box=box_prompt, multimask_output=False)
     mask = masks[0]
     mask_bool = mask > 0
+
     cropped_img_rgba = np.zeros((height, width, 4), dtype=np.uint8)
     cropped_img_rgba[:, :, :3] = image_np
     cropped_img_rgba[:, :, 3] = mask_bool * 255
     cropped_img_rgba = cropped_img_rgba[y1:y2, x1:x2]
+
     return Image.fromarray(cropped_img_rgba, 'RGBA')
 
+
 def extract_features(segmented_image: Image.Image) -> dict:
     image_rgba = np.array(segmented_image)
-    if image_rgba.shape[2] != 4: raise ValueError("Segmented image must be RGBA")
+    if image_rgba.shape[2] != 4:
+        raise ValueError("Segmented image must be RGBA")
+
     b, g, r, a = cv2.split(image_rgba)
     image_rgb = cv2.merge((b, g, r))
     mask = a
+
     gray = cv2.cvtColor(image_rgb, cv2.COLOR_BGR2GRAY)
     contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-    if not contours:
-        # If no contours are found, return zero-filled features
-        print(" [Features] ⚠ Warning: No contours found in segmented image. Returning zero features.")
-        return {
-            "shape_features": [0.0] * 7,
-            "color_features": [0.0] * 512, # 8*8*8
-            "texture_features": [0.0] * 26
-        }
-    hu_moments = cv2.HuMoments(cv2.moments(contours[0])).flatten()
-    color_hist = cv2.calcHist([image_rgb], [0, 1, 2], mask, [8, 8, 8], [0, 256, 0, 256, 0, 256])
+    hu_moments = cv2.HuMoments(cv2.moments(contours[0])).flatten() if contours else np.zeros(7)
+
+    color_hist = cv2.calcHist([image_rgb], [0, 1, 2], mask, [8, 8, 8],
+                              [0, 256, 0, 256, 0, 256])
     cv2.normalize(color_hist, color_hist)
     color_hist = color_hist.flatten()
+
     gray_masked = cv2.bitwise_and(gray, gray, mask=mask)
     lbp = feature.local_binary_pattern(gray_masked, P=24, R=3, method="uniform")
     (texture_hist, _) = np.histogram(lbp.ravel(), bins=np.arange(0, 27), range=(0, 26))
     texture_hist = texture_hist.astype("float32")
     texture_hist /= (texture_hist.sum() + 1e-6)
+
     return {
         "shape_features": hu_moments.tolist(),
         "color_features": color_hist.tolist(),
         "texture_features": texture_hist.tolist()
     }
 
+
 def get_text_embedding(text: str, models: dict) -> list:
     print(f" [Embedding] Generating text embedding for: '{text[:50]}...'")
-    text_with_instruction = f"Represent this description of a lost item for similarity search: {text}"
-    inputs = models['tokenizer_text'](text_with_instruction, return_tensors='pt', padding=True, truncation=True, max_length=512).to(models['device'])
+    text_with_instruction = f"Represent this sentence for searching relevant passages: {text}"
+
+    inputs = models['tokenizer_text'](
+        text_with_instruction,
+        return_tensors='pt',
+        padding=True,
+        truncation=True,
+        max_length=512
+    ).to(models['device'])
+
     with torch.no_grad():
         outputs = models['model_text'](**inputs)
         embedding = outputs.last_hidden_state[:, 0, :]
         embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
+
     print(" [Embedding] ✅ Text embedding generated.")
     return embedding.cpu().numpy()[0].tolist()
 
+
 def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
     return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))
 
+
 def stretch_image_score(score):
-    if score < 0.4 or score == 1.0: return score
-    return 0.7 + (score - 0.4) * (0.99 - 0.7) / (1.0 - 0.4)
+    if score < 0.4 or score == 1.0:
+        return score
+    return 0.7 + (score - 0.4) * (0.99 - 0.7) / (1.0 - 0.4)
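
Below is a minimal, hypothetical usage sketch of the helpers touched by this commit. It assumes a `models` dict (keys 'processor_gnd', 'model_gnd', 'predictor', 'tokenizer_text', 'model_text', 'device') that the application loads elsewhere; the describe_item/compare_descriptions wrappers and the RGBA fallback conversion are illustrative only and not part of the repository.

# Hypothetical wiring of pipeline/logic.py; model loading happens elsewhere
# in the app and is not part of this commit.
import numpy as np

from pipeline.logic import (
    get_canonical_label,
    download_image_from_url,
    detect_and_crop,
    extract_features,
    get_text_embedding,
    cosine_similarity,
)

def describe_item(image_url: str, description: str, models: dict) -> dict:
    label = get_canonical_label(description)      # e.g. "a black leather wallet" -> "wallet"
    image = download_image_from_url(image_url)    # PIL image standardized to RGB
    crop = detect_and_crop(image, label, models)  # RGBA crop, or the full RGB image on a miss
    if crop.mode != "RGBA":                       # extract_features() requires an alpha channel
        crop = crop.convert("RGBA")
    return {
        "label": label,
        "features": extract_features(crop),                    # Hu moments + color + LBP histograms
        "embedding": get_text_embedding(description, models),  # L2-normalized CLS vector as a list
    }

def compare_descriptions(lost: dict, found: dict) -> float:
    # Text-side similarity between two items produced by describe_item().
    return cosine_similarity(np.array(lost["embedding"]), np.array(found["embedding"]))

# For visual matches, stretch_image_score() rescales raw similarities in
# [0.4, 1.0) into [0.7, 0.99); scores below 0.4 (and exactly 1.0) pass through unchanged.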