Guill-Bla commited on
Commit
4b9ac2a
·
verified ·
1 Parent(s): 5c34ca1

Update tasks/image.py

Browse files
Files changed (1) hide show
  1. tasks/image.py +22 -20
tasks/image.py CHANGED
@@ -35,13 +35,16 @@ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
35
  model.eval()
36
 
37
  def preprocess(image):
 
38
  image = image.resize((512, 512))
39
- image = np.array(image)[:, :, ::-1] # RGB to BGR
 
 
40
  image = np.array(image, dtype=np.float32) / 255.0
41
 
42
- # Convert back to PIL Image to maintain compatibility with feature extractor
43
- image = Image.fromarray((image * 255).astype(np.uint8))
44
- return image
45
 
46
 
47
  def get_bounding_boxes_from_mask(mask):
@@ -145,16 +148,11 @@ async def evaluate_image(request: ImageEvaluationRequest):
145
  # Extract image and annotations
146
  image = example["image"]
147
 
148
- original_shape = image.size
149
- image = preprocess(image)
150
-
151
  annotation = example.get("annotations", "").strip()
152
-
153
-
154
  has_smoke = len(annotation) > 0
155
  true_labels.append(1 if has_smoke else 0)
156
 
157
-
158
  if has_smoke:
159
  image_true_boxes = parse_boxes(annotation)
160
  if image_true_boxes:
@@ -165,26 +163,30 @@ async def evaluate_image(request: ImageEvaluationRequest):
165
  true_boxes_list.append([])
166
 
167
  # Model Inference
168
- # image_input = feature_extractor(images=image, return_tensors="pt").pixel_values
169
- image_input = feature_extractor(images=[image], return_tensors="pt", padding=True).pixel_values
170
-
 
 
 
 
 
171
  with torch.no_grad():
172
  outputs = model(pixel_values=image_input)
173
  logits = outputs.logits
174
-
 
175
  probabilities = torch.sigmoid(logits)
176
  predicted_mask = (probabilities[0, 1] > 0.30).cpu().numpy().astype(np.uint8)
177
- # predicted_mask_resized = cv2.resize(predicted_mask, (512, 512), interpolation=cv2.INTER_NEAREST)
178
  predicted_mask_resized = cv2.resize(predicted_mask, original_shape[::-1], interpolation=cv2.INTER_NEAREST)
179
-
180
 
181
- # Extract predicted bounding boxes
182
  predicted_boxes = get_bounding_boxes_from_mask(predicted_mask_resized)
183
  pred_boxes.append(predicted_boxes)
184
-
185
- # Binary prediction for smoke detection
 
186
  print(f"Prediction : {1 if len(predicted_boxes) > 0 else 0}")
187
- predictions.append(1 if len(predicted_boxes) > 0 else 0)
188
 
189
 
190
  # Filter only valid box pairs
 
35
  model.eval()
36
 
37
  def preprocess(image):
38
+ # Ensure input image is resized to a fixed size (512, 512)
39
  image = image.resize((512, 512))
40
+
41
+ # Convert to NumPy and ensure BGR normalization
42
+ image = np.array(image)[:, :, ::-1] # Convert RGB to BGR
43
  image = np.array(image, dtype=np.float32) / 255.0
44
 
45
+ # Return as a PIL Image for feature extractor compatibility
46
+ return Image.fromarray((image * 255).astype(np.uint8))
47
+
48
 
49
 
50
  def get_bounding_boxes_from_mask(mask):
 
148
  # Extract image and annotations
149
  image = example["image"]
150
 
151
+ original_shape = image.size
 
 
152
  annotation = example.get("annotations", "").strip()
 
 
153
  has_smoke = len(annotation) > 0
154
  true_labels.append(1 if has_smoke else 0)
155
 
 
156
  if has_smoke:
157
  image_true_boxes = parse_boxes(annotation)
158
  if image_true_boxes:
 
163
  true_boxes_list.append([])
164
 
165
  # Model Inference
166
+
167
+ # Preprocess image
168
+ image = preprocess(image)
169
+
170
+ # Ensure correct feature extraction
171
+ image_input = feature_extractor(images=image, return_tensors="pt").pixel_values
172
+
173
+ # Perform inference
174
  with torch.no_grad():
175
  outputs = model(pixel_values=image_input)
176
  logits = outputs.logits
177
+
178
+ # Threshold and process the segmentation mask
179
  probabilities = torch.sigmoid(logits)
180
  predicted_mask = (probabilities[0, 1] > 0.30).cpu().numpy().astype(np.uint8)
 
181
  predicted_mask_resized = cv2.resize(predicted_mask, original_shape[::-1], interpolation=cv2.INTER_NEAREST)
 
182
 
183
+ # Extract bounding boxes
184
  predicted_boxes = get_bounding_boxes_from_mask(predicted_mask_resized)
185
  pred_boxes.append(predicted_boxes)
186
+
187
+ # Smoke prediction based on bounding box presence
188
+ predictions.append(1 if len(predicted_boxes) > 0 else 0)
189
  print(f"Prediction : {1 if len(predicted_boxes) > 0 else 0}")
 
190
 
191
 
192
  # Filter only valid box pairs