Guill-Bla commited on
Commit
ca9f528
·
verified ·
1 Parent(s): 3e4ce08

Update tasks/image.py

Browse files
Files changed (1) hide show
  1. tasks/image.py +98 -44
tasks/image.py CHANGED
@@ -6,6 +6,8 @@ from sklearn.metrics import accuracy_score
6
  import random
7
  import os
8
 
 
 
9
  from ultralytics import YOLO
10
  from .utils.evaluation import ImageEvaluationRequest
11
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
@@ -45,7 +47,19 @@ def preprocess(image):
45
  # Return as a PIL Image for feature extractor compatibility
46
  return Image.fromarray((image * 255).astype(np.uint8))
47
 
48
-
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  def get_bounding_boxes_from_mask(mask):
51
  """Extract bounding boxes from a binary mask."""
@@ -126,7 +140,7 @@ async def evaluate_image(request: ImageEvaluationRequest):
126
 
127
  # Load and prepare the dataset
128
  dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))
129
-
130
  # Split dataset
131
  train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
132
  test_dataset = dataset["val"]#train_test["test"]
@@ -139,67 +153,107 @@ async def evaluate_image(request: ImageEvaluationRequest):
139
  # YOUR MODEL INFERENCE CODE HERE
140
  # Update the code below to replace the random baseline with your model inference
141
  #--------------------------------------------------------------------------------------------
 
 
 
142
  predictions = []
143
  true_labels = []
144
  pred_boxes = []
145
  true_boxes_list = []
146
 
147
- for example in test_dataset:
148
- # Extract image and annotations
149
- image = example["image"]
 
 
 
 
150
 
151
- original_shape = image.size
152
- annotation = example.get("annotations", "").strip()
153
- has_smoke = len(annotation) > 0
154
- true_labels.append(1 if has_smoke else 0)
155
-
156
- if has_smoke:
157
- image_true_boxes = parse_boxes(annotation)
158
- if image_true_boxes:
159
- true_boxes_list.append(image_true_boxes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  else:
161
  true_boxes_list.append([])
162
- else:
163
- true_boxes_list.append([])
164
 
165
- # Model Inference
 
 
 
 
 
 
 
166
 
167
- # Preprocess image
168
- image = preprocess(image)
 
 
 
 
 
 
 
 
 
 
 
169
 
170
- # Ensure correct feature extraction
171
- image_input = feature_extractor(images=image, return_tensors="pt").pixel_values
172
 
173
- # Perform inference
174
- with torch.no_grad():
175
- outputs = model(pixel_values=image_input)
176
- logits = outputs.logits
177
 
178
- # Threshold and process the segmentation mask
179
- probabilities = torch.sigmoid(logits)
180
- predicted_mask = (probabilities[0, 1] > 0.30).cpu().numpy().astype(np.uint8)
181
- predicted_mask_resized = cv2.resize(predicted_mask, (512,512), interpolation=cv2.INTER_NEAREST)
182
 
183
- # Extract bounding boxes
184
- predicted_boxes = get_bounding_boxes_from_mask(predicted_mask_resized)
185
- pred_boxes.append(predicted_boxes)
186
 
187
- # Smoke prediction based on bounding box presence
188
- predictions.append(1 if len(predicted_boxes) > 0 else 0)
189
- print(f"Prediction : {1 if len(predicted_boxes) > 0 else 0}")
190
 
191
 
192
- # Filter only valid box pairs
193
- filtered_true_boxes_list = []
194
- filtered_pred_boxes = []
195
 
196
- for true_boxes, pred_boxes_entry in zip(true_boxes_list, pred_boxes):
197
- if true_boxes and pred_boxes_entry:
198
- filtered_true_boxes_list.append(true_boxes)
199
- filtered_pred_boxes.append(pred_boxes_entry)
200
 
201
- true_boxes_list = filtered_true_boxes_list
202
- pred_boxes = filtered_pred_boxes
203
 
204
 
205
  #--------------------------------------------------------------------------------------------
 
6
  import random
7
  import os
8
 
9
+ from torch.utils.data import DataLoader
10
+
11
  from ultralytics import YOLO
12
  from .utils.evaluation import ImageEvaluationRequest
13
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
47
  # Return as a PIL Image for feature extractor compatibility
48
  return Image.fromarray((image * 255).astype(np.uint8))
49
 
50
+ def preprocess_batch(images):
51
+ """
52
+ Preprocess a batch of images for MobileViT inference.
53
+ Resize to a fixed size (512, 512) and return as PIL Images.
54
+ """
55
+ preprocessed_images = []
56
+ for image in images:
57
+ resized_image = image.resize((512, 512))
58
+ image_array = np.array(resized_image)[:, :, ::-1] # Convert RGB to BGR
59
+ image_float = np.array(image_array, dtype=np.float32) / 255.0
60
+ processed_image = Image.fromarray((image_float * 255).astype(np.uint8))
61
+ preprocessed_images.append(processed_image)
62
+ return preprocessed_images
63
 
64
  def get_bounding_boxes_from_mask(mask):
65
  """Extract bounding boxes from a binary mask."""
 
140
 
141
  # Load and prepare the dataset
142
  dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))
143
+
144
  # Split dataset
145
  train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
146
  test_dataset = dataset["val"]#train_test["test"]
 
153
  # YOUR MODEL INFERENCE CODE HERE
154
  # Update the code below to replace the random baseline with your model inference
155
  #--------------------------------------------------------------------------------------------
156
+
157
+ dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
158
+
159
  predictions = []
160
  true_labels = []
161
  pred_boxes = []
162
  true_boxes_list = []
163
 
164
+ for batch_idx, batch_examples in enumerate(dataloader):
165
+ # Extract images and preprocess
166
+ images = [example["image"] for example in batch_examples]
167
+ annotations = [example.get("annotations", "").strip() for example in batch_examples]
168
+
169
+ has_smoke_list = [len(annotation) > 0 for annotation in annotations]
170
+ true_labels.extend([1 if has_smoke else 0 for has_smoke in has_smoke_list])
171
 
172
+ # Preprocess images and extract features
173
+ preprocessed_images = preprocess_batch(images)
174
+ image_inputs = feature_extractor(images=preprocessed_images, return_tensors="pt", padding=True).pixel_values
175
+
176
+ # Perform inference
177
+ with torch.no_grad():
178
+ outputs = model(pixel_values=image_inputs)
179
+ logits = outputs.logits
180
+
181
+ # Threshold and process the segmentation masks
182
+ probabilities = torch.sigmoid(logits)
183
+ batch_predicted_masks = (probabilities[:, 1, :, :] > 0.30).cpu().numpy().astype(np.uint8)
184
+
185
+ for mask in batch_predicted_masks:
186
+ mask_resized = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)
187
+ predicted_boxes = get_bounding_boxes_from_mask(mask_resized)
188
+ pred_boxes.append(predicted_boxes)
189
+
190
+ # Append smoke detection based on bounding boxes
191
+ predictions.append(1 if len(predicted_boxes) > 0 else 0)
192
+ print(f"Batch {batch_idx + 1}, Image Prediction: {1 if len(predicted_boxes) > 0 else 0}")
193
+
194
+ # Parse true boxes for this batch
195
+ for annotation in annotations:
196
+ if len(annotation) > 0:
197
+ true_boxes_list.append(parse_boxes(annotation))
198
  else:
199
  true_boxes_list.append([])
 
 
200
 
201
+ # for example in test_dataset:
202
+ # # Extract image and annotations
203
+ # image = example["image"]
204
+
205
+ # original_shape = image.size
206
+ # annotation = example.get("annotations", "").strip()
207
+ # has_smoke = len(annotation) > 0
208
+ # true_labels.append(1 if has_smoke else 0)
209
 
210
+ # if has_smoke:
211
+ # image_true_boxes = parse_boxes(annotation)
212
+ # if image_true_boxes:
213
+ # true_boxes_list.append(image_true_boxes)
214
+ # else:
215
+ # true_boxes_list.append([])
216
+ # else:
217
+ # true_boxes_list.append([])
218
+
219
+ # # Model Inference
220
+
221
+ # # Preprocess image
222
+ # image = preprocess(image)
223
 
224
+ # # Ensure correct feature extraction
225
+ # image_input = feature_extractor(images=image, return_tensors="pt").pixel_values
226
 
227
+ # # Perform inference
228
+ # with torch.no_grad():
229
+ # outputs = model(pixel_values=image_input)
230
+ # logits = outputs.logits
231
 
232
+ # # Threshold and process the segmentation mask
233
+ # probabilities = torch.sigmoid(logits)
234
+ # predicted_mask = (probabilities[0, 1] > 0.30).cpu().numpy().astype(np.uint8)
235
+ # predicted_mask_resized = cv2.resize(predicted_mask, (512,512), interpolation=cv2.INTER_NEAREST)
236
 
237
+ # # Extract bounding boxes
238
+ # predicted_boxes = get_bounding_boxes_from_mask(predicted_mask_resized)
239
+ # pred_boxes.append(predicted_boxes)
240
 
241
+ # # Smoke prediction based on bounding box presence
242
+ # predictions.append(1 if len(predicted_boxes) > 0 else 0)
243
+ # print(f"Prediction : {1 if len(predicted_boxes) > 0 else 0}")
244
 
245
 
246
+ # # Filter only valid box pairs
247
+ # filtered_true_boxes_list = []
248
+ # filtered_pred_boxes = []
249
 
250
+ # for true_boxes, pred_boxes_entry in zip(true_boxes_list, pred_boxes):
251
+ # if true_boxes and pred_boxes_entry:
252
+ # filtered_true_boxes_list.append(true_boxes)
253
+ # filtered_pred_boxes.append(pred_boxes_entry)
254
 
255
+ # true_boxes_list = filtered_true_boxes_list
256
+ # pred_boxes = filtered_pred_boxes
257
 
258
 
259
  #--------------------------------------------------------------------------------------------