Guill-Bla committed
Commit de2943e · verified · 1 Parent(s): bb54cea

Update tasks/image.py

Files changed (1)
  1. tasks/image.py +36 -20
tasks/image.py CHANGED
@@ -36,6 +36,29 @@ model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobile
 model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
 model.eval()
 
+from torch.utils.data import Dataset
+
+class SmokeDataset(Dataset):
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        example = self.dataset[idx]
+        image = example["image"]
+        annotation = example.get("annotations", "").strip()
+
+        # Resize and preprocess the image directly here
+        image = image.resize((512, 512))
+        image = np.array(image)[:, :, ::-1]  # Convert RGB to BGR
+        image = np.array(image, dtype=np.float32) / 255.0
+
+        # Return both the preprocessed image and annotation
+        return torch.tensor(image).permute(2, 0, 1), annotation
+
+
 def preprocess(image):
     # Ensure input image is resized to a fixed size (512, 512)
     image = image.resize((512, 512))
@@ -153,39 +176,32 @@ async def evaluate_image(request: ImageEvaluationRequest):
     # YOUR MODEL INFERENCE CODE HERE
     # Update the code below to replace the random baseline with your model inference
     #--------------------------------------------------------------------------------------------
-
-    dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
+    smoke_dataset = SmokeDataset(test_dataset)
+    dataloader = DataLoader(smoke_dataset, batch_size=16, shuffle=False)
 
     predictions = []
     true_labels = []
     pred_boxes = []
     true_boxes_list = []
 
-    for batch_idx, batch_examples in enumerate(dataloader):
-        # Extract images and preprocess
-        images = [example["image"] for example in batch_examples]
-        annotations = [example.get("annotations", "").strip() for example in batch_examples]
-
-        has_smoke_list = [len(annotation) > 0 for annotation in annotations]
-        true_labels.extend([1 if has_smoke else 0 for has_smoke in has_smoke_list])
-
-        # Preprocess images and extract features
-        preprocessed_images = preprocess_batch(images)
-        image_inputs = feature_extractor(images=preprocessed_images, return_tensors="pt", padding=True).pixel_values
-
+    for batch_images, batch_annotations in dataloader:
+        image_inputs = feature_extractor(images=batch_images, return_tensors="pt", padding=True).pixel_values
+
         # Perform inference
         with torch.no_grad():
             outputs = model(pixel_values=image_inputs)
             logits = outputs.logits
-
-        # Threshold and process the segmentation masks
+
         probabilities = torch.sigmoid(logits)
         batch_predicted_masks = (probabilities[:, 1, :, :] > 0.30).cpu().numpy().astype(np.uint8)
-
-        for mask in batch_predicted_masks:
-            mask_resized = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)
-            predicted_boxes = get_bounding_boxes_from_mask(mask_resized)
+
+        # Post-process predictions and compute metrics
+        for mask, annotation in zip(batch_predicted_masks, batch_annotations):
+            predicted_mask_resized = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)
+            predicted_boxes = get_bounding_boxes_from_mask(predicted_mask_resized)
             pred_boxes.append(predicted_boxes)
+            predictions.append(1 if len(predicted_boxes) > 0 else 0)
+            true_labels.append(1 if annotation else 0)
 
             # Append smoke detection based on bounding boxes
             predictions.append(1 if len(predicted_boxes) > 0 else 0)
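Note: the new SmokeDataset returns a (3, 512, 512) float tensor plus the raw annotation string per example, so PyTorch's default collation stacks the images into a single (batch, 3, 512, 512) tensor and groups the strings into a per-batch list. A minimal, self-contained sketch of that behaviour (the _ToyDataset below is only a stand-in for the wrapped test split; it is not part of the commit):

import torch
from torch.utils.data import DataLoader, Dataset

class _ToyDataset(Dataset):
    # Stand-in yielding (image_tensor, annotation_string), shaped like SmokeDataset items.
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        image = torch.zeros(3, 512, 512, dtype=torch.float32)  # fake preprocessed image
        annotation = "smoke" if idx % 2 == 0 else ""            # fake annotation string
        return image, annotation

loader = DataLoader(_ToyDataset(), batch_size=2, shuffle=False)

for batch_images, batch_annotations in loader:
    print(batch_images.shape)    # torch.Size([2, 3, 512, 512]) -- tensors stacked by default_collate
    print(batch_annotations)     # ['smoke', ''] -- strings kept as a plain per-sample list
    break

This is why the loop in the second hunk can zip batch_predicted_masks with batch_annotations, and why "1 if annotation else 0" reproduces the old "len(annotation) > 0" labelling.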
 
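Both versions of the loop call a get_bounding_boxes_from_mask helper defined elsewhere in tasks/image.py; its real implementation is not shown in this diff. As a rough, hypothetical sketch of what such a helper could look like (connected regions of the binary mask turned into boxes with OpenCV; the (x, y, w, h) box format is an assumption, not taken from the commit):

import cv2
import numpy as np

def get_bounding_boxes_from_mask(mask):
    # Hypothetical sketch: one (x, y, w, h) box per external contour of a binary mask.
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    boxes = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        boxes.append((x, y, w, h))
    return boxes

# Example: a 512x512 mask with one rectangular smoke blob yields a single box.
demo_mask = np.zeros((512, 512), dtype=np.uint8)
demo_mask[100:200, 150:300] = 1
print(get_bounding_boxes_from_mask(demo_mask))  # [(150, 100, 150, 100)]

Whatever the real helper returns, the committed code only uses its length: an image is counted as containing smoke exactly when at least one box is found.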