sohamnk committed on
Commit
9ad8225
·
verified ·
1 Parent(s): ccb5338

Update pipeline/logic.py

Files changed (1)
  1. pipeline/logic.py +11 -54
pipeline/logic.py CHANGED
@@ -1,17 +1,13 @@
-from PIL import ImageFile
-ImageFile.LOAD_TRUNCATED_IMAGES = True
-
 import numpy as np
 import requests
 import cv2
 from skimage import feature
 from io import BytesIO
-from PIL import Image, ImageFile
+from PIL import Image
 import torch
-
+from PIL import ImageFile
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 
-
 def get_canonical_label(object_name_phrase: str) -> str:
     print(f"\n [Label] Extracting label for: '{object_name_phrase}'")
     label = object_name_phrase.strip().lower().split()[-1]
@@ -19,7 +15,6 @@ def get_canonical_label(object_name_phrase: str) -> str:
     print(f" [Label] ✅ Extracted label: '{label}'")
     return label if label else "unknown"
 
-
 def download_image_from_url(image_url: str) -> Image.Image:
     print(f" [Download] Downloading image from: {image_url[:80]}...")
     response = requests.get(image_url)
@@ -29,112 +24,74 @@ def download_image_from_url(image_url: str) -> Image.Image:
     print(" [Download] ✅ Image downloaded and standardized to RGB.")
     return image_rgb
 
-
 def detect_and_crop(image: Image.Image, object_name: str, models: dict) -> Image.Image:
     print(f"\n [Detect & Crop] Starting detection for object: '{object_name}'")
     image_np = np.array(image.convert("RGB"))
     height, width = image_np.shape[:2]
-
     prompt = [[f"a {object_name}"]]
-    inputs = models['processor_gnd'](
-        images=image,
-        text=prompt,
-        return_tensors="pt"
-    ).to(models['device'])
-
+    inputs = models['processor_gnd'](images=image, text=prompt, return_tensors="pt").to(models['device'])
     with torch.no_grad():
         outputs = models['model_gnd'](**inputs)
-    results = models['processor_gnd'].post_process_grounded_object_detection(
-        outputs,
-        inputs.input_ids,
-        box_threshold=0.4,
-        text_threshold=0.3,
-        target_sizes=[(height, width)]
-    )
-
+    results = models['processor_gnd'].post_process_grounded_object_detection(
+        outputs, inputs.input_ids, box_threshold=0.4, text_threshold=0.3, target_sizes=[(height, width)]
+    )
     if not results or len(results[0]['boxes']) == 0:
         print(" [Detect & Crop] ⚠️ Warning: Grounding DINO did not detect the object. Using full image.")
         return image
-
     result = results[0]
     scores = result['scores']
     max_idx = int(torch.argmax(scores))
     box = result['boxes'][max_idx].cpu().numpy().astype(int)
     print(f" [Detect & Crop] ✅ Object detected with confidence: {scores[max_idx]:.2f}, Box: {box}")
-
     x1, y1, x2, y2 = box
     models['predictor'].set_image(image_np)
     box_prompt = np.array([[x1, y1, x2, y2]])
-
     masks, _, _ = models['predictor'].predict(box=box_prompt, multimask_output=False)
     mask = masks[0]
     mask_bool = mask > 0
-
     cropped_img_rgba = np.zeros((height, width, 4), dtype=np.uint8)
     cropped_img_rgba[:, :, :3] = image_np
     cropped_img_rgba[:, :, 3] = mask_bool * 255
     cropped_img_rgba = cropped_img_rgba[y1:y2, x1:x2]
-
     return Image.fromarray(cropped_img_rgba, 'RGBA')
 
-
 def extract_features(segmented_image: Image.Image) -> dict:
     image_rgba = np.array(segmented_image)
-    if image_rgba.shape[2] != 4:
-        raise ValueError("Segmented image must be RGBA")
-
+    if image_rgba.shape[2] != 4: raise ValueError("Segmented image must be RGBA")
     b, g, r, a = cv2.split(image_rgba)
     image_rgb = cv2.merge((b, g, r))
     mask = a
-
     gray = cv2.cvtColor(image_rgb, cv2.COLOR_BGR2GRAY)
     contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
     hu_moments = cv2.HuMoments(cv2.moments(contours[0])).flatten() if contours else np.zeros(7)
-
-    color_hist = cv2.calcHist([image_rgb], [0, 1, 2], mask, [8, 8, 8],
-                              [0, 256, 0, 256, 0, 256])
+    color_hist = cv2.calcHist([image_rgb], [0, 1, 2], mask, [8, 8, 8], [0, 256, 0, 256, 0, 256])
     cv2.normalize(color_hist, color_hist)
     color_hist = color_hist.flatten()
-
     gray_masked = cv2.bitwise_and(gray, gray, mask=mask)
     lbp = feature.local_binary_pattern(gray_masked, P=24, R=3, method="uniform")
     (texture_hist, _) = np.histogram(lbp.ravel(), bins=np.arange(0, 27), range=(0, 26))
     texture_hist = texture_hist.astype("float32")
    texture_hist /= (texture_hist.sum() + 1e-6)
-
     return {
         "shape_features": hu_moments.tolist(),
         "color_features": color_hist.tolist(),
         "texture_features": texture_hist.tolist()
     }
 
-
 def get_text_embedding(text: str, models: dict) -> list:
     print(f" [Embedding] Generating text embedding for: '{text[:50]}...'")
     text_with_instruction = f"Represent this sentence for searching relevant passages: {text}"
-
-    inputs = models['tokenizer_text'](
-        text_with_instruction,
-        return_tensors='pt',
-        padding=True,
-        truncation=True,
-        max_length=512
-    ).to(models['device'])
-
+    inputs = models['tokenizer_text'](text_with_instruction, return_tensors='pt', padding=True, truncation=True, max_length=512).to(models['device'])
    with torch.no_grad():
         outputs = models['model_text'](**inputs)
     embedding = outputs.last_hidden_state[:, 0, :]
     embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
-
     print(" [Embedding] ✅ Text embedding generated.")
     return embedding.cpu().numpy()[0].tolist()
 
-
 def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
     return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))
 
-
 def stretch_image_score(score):
-    if score < 0.4 or score == 1.0:
-        return score
-    return 0.7 + (score - 0.4) * (0.99 - 0.7) / (1.0 - 0.4)
+    if score < 0.4 or score == 1.0: return score
+    return 0.7 + (score - 0.4) * (0.99 - 0.7) / (1.0 - 0.4)
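
For context, every model-backed function in this file takes a preloaded `models` dict keyed by 'processor_gnd', 'model_gnd', 'predictor', 'tokenizer_text', 'model_text', and 'device'. A minimal sketch of how such a dict might be assembled is given below; the specific checkpoints (grounding-dino-tiny, a ViT-B SAM weight file, bge-base-en-v1.5) are illustrative assumptions, not necessarily what this repo loads.

# Hypothetical setup for the `models` dict used throughout pipeline/logic.py.
# The checkpoint names below are assumptions chosen for illustration only.
import torch
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection, AutoTokenizer, AutoModel
from segment_anything import sam_model_registry, SamPredictor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Grounding DINO detector and its processor (assumed checkpoint).
processor_gnd = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
model_gnd = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny").to(device)

# SAM predictor for box-prompted segmentation (assumed local ViT-B checkpoint).
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
predictor = SamPredictor(sam.to(device))

# Text encoder used by get_text_embedding (assumed BGE-style model with CLS pooling).
tokenizer_text = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")
model_text = AutoModel.from_pretrained("BAAI/bge-base-en-v1.5").to(device)

models = {
    "processor_gnd": processor_gnd,
    "model_gnd": model_gnd,
    "predictor": predictor,
    "tokenizer_text": tokenizer_text,
    "model_text": model_text,
    "device": device,
}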
 
 
 
 
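
The pure helpers can be exercised without any models. A small usage sketch follows, assuming pipeline/logic.py is importable as pipeline.logic:

# Minimal sketch exercising only the dependency-free helpers; the model-backed
# functions (detect_and_crop, get_text_embedding) need the `models` dict above.
import numpy as np

from pipeline.logic import get_canonical_label, cosine_similarity, stretch_image_score

print(get_canonical_label("a red ceramic mug"))  # returns "mug", the last token of the phrase

v1 = np.array([1.0, 0.0, 1.0])
v2 = np.array([1.0, 1.0, 0.0])
print(cosine_similarity(v1, v2))                 # ~0.5

# stretch_image_score linearly remaps [0.4, 1.0) onto [0.7, 0.99);
# scores below 0.4 and an exact 1.0 pass through unchanged.
print(stretch_image_score(0.3))                  # 0.3 (unchanged)
print(stretch_image_score(0.7))                  # ~0.845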