Guill-Bla commited on
Commit
f491cd6
·
verified ·
1 Parent(s): 33584bc

Update tasks/image.py

Browse files
Files changed (1) hide show
  1. tasks/image.py +49 -20
tasks/image.py CHANGED
@@ -6,20 +6,40 @@ from sklearn.metrics import accuracy_score
6
  import random
7
  import os
8
 
9
- from ultralytics import YOLO # Import YOLO
10
- from .utils.evaluation import ImageEvaluationRequest
11
- from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
 
 
 
 
 
12
 
13
  from dotenv import load_dotenv
14
  load_dotenv()
15
 
16
  router = APIRouter()
17
 
18
- DESCRIPTION = "YOLO Smoke Detection"
19
  ROUTE = "/image"
20
 
21
- yolo_model = YOLO("best.pt")
 
 
 
 
22
 
 
 
 
 
 
 
 
 
 
 
23
  def parse_boxes(annotation_string):
24
  """Parse multiple boxes from a single annotation string.
25
  Each box has 5 values: class_id, x_center, y_center, width, height"""
@@ -93,7 +113,7 @@ async def evaluate_image(request: ImageEvaluationRequest):
93
  # Split dataset
94
  train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
95
  test_dataset = dataset["val"]#train_test["test"]
96
-
97
  # Start tracking emissions
98
  tracker.start()
99
  tracker.start_task("inference")
@@ -126,28 +146,37 @@ async def evaluate_image(request: ImageEvaluationRequest):
126
  else:
127
  true_boxes_list.append([])
128
 
129
- # results = yolo_model .predict(image, verbose=False) # INFERENCE - prediction
130
- results = yolo_model.predict(image) # INFERENCE - prediction
131
-
132
- if len(results[0].boxes):
133
- pred_box = results[0].boxes.xywhn[0].cpu().numpy().tolist()
134
- predictions.append(1)
135
- pred_boxes.append(pred_box)
136
- else:
137
- predictions.append(0)
138
- pred_boxes.append([])
 
 
 
 
 
 
 
 
139
 
 
140
  filtered_true_boxes_list = []
141
  filtered_pred_boxes = []
142
-
143
- for true_boxes, pred_boxes_entry in zip(true_boxes_list, pred_boxes): # Only see when annotation(s) is/are both on true label and prediction
144
- if true_boxes and pred_boxes_entry:
145
  filtered_true_boxes_list.append(true_boxes)
146
  filtered_pred_boxes.append(pred_boxes_entry)
147
 
148
-
149
  true_boxes_list = filtered_true_boxes_list
150
  pred_boxes = filtered_pred_boxes
 
151
 
152
  #--------------------------------------------------------------------------------------------
153
  # YOUR MODEL INFERENCE STOPS HERE
 
6
  import random
7
  import os
8
 
9
+ import os
10
+ import torch
11
+ import numpy as np
12
+ from PIL import Image
13
+ from transformers import MobileViTImageProcessor, MobileViTForSemanticSegmentation
14
+ import cv2
15
+ from tqdm import tqdm
16
+ from dataset import WildfireSmokeDataset
17
+ from torch.utils.data import DataLoader
18
 
19
  from dotenv import load_dotenv
20
  load_dotenv()
21
 
22
  router = APIRouter()
23
 
24
+ DESCRIPTION = "Mobile-ViT Smoke Detection"
25
  ROUTE = "/image"
26
 
27
+ model_path = "mobilevit_segmentation_full_data.pth"
28
+ feature_extractor = MobileViTImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
29
+ model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
30
+ model.load_state_dict(torch.load(model_path))
31
+ model.eval()
32
 
33
+ def get_bounding_boxes_from_mask(mask):
34
+ """Extract bounding boxes from a binary mask."""
35
+ pred_boxes = []
36
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
37
+ for contour in contours:
38
+ if len(contour) > 5: # Ignore small/noisy contours
39
+ x, y, w, h = cv2.boundingRect(contour)
40
+ pred_boxes.append((x, y, x + w, y + h))
41
+ return pred_boxes
42
+
43
  def parse_boxes(annotation_string):
44
  """Parse multiple boxes from a single annotation string.
45
  Each box has 5 values: class_id, x_center, y_center, width, height"""
 
113
  # Split dataset
114
  train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
115
  test_dataset = dataset["val"]#train_test["test"]
116
+
117
  # Start tracking emissions
118
  tracker.start()
119
  tracker.start_task("inference")
 
146
  else:
147
  true_boxes_list.append([])
148
 
149
+ # Model Inference
150
+ image_input = feature_extractor(images=image, return_tensors="pt").pixel_values
151
+ with torch.no_grad():
152
+ outputs = model(pixel_values=image_input)
153
+ logits = outputs.logits
154
+
155
+ probabilities = torch.sigmoid(logits)
156
+ predicted_mask = (probabilities[0, 1] > 0.30).cpu().numpy().astype(np.uint8)
157
+ predicted_mask_resized = cv2.resize(predicted_mask, (512, 512), interpolation=cv2.INTER_NEAREST)
158
+
159
+ # Extract predicted bounding boxes
160
+ predicted_boxes = get_bounding_boxes_from_mask(predicted_mask_resized)
161
+ pred_boxes.append(predicted_boxes)
162
+
163
+ # Binary prediction for smoke detection
164
+ print(1 if len(predicted_boxes) > 0 else 0)
165
+ predictions.append(1 if len(predicted_boxes) > 0 else 0)
166
+
167
 
168
+ # Filter only valid box pairs
169
  filtered_true_boxes_list = []
170
  filtered_pred_boxes = []
171
+
172
+ for true_boxes, pred_boxes_entry in zip(true_boxes_list, pred_boxes):
173
+ if true_boxes and pred_boxes_entry:
174
  filtered_true_boxes_list.append(true_boxes)
175
  filtered_pred_boxes.append(pred_boxes_entry)
176
 
 
177
  true_boxes_list = filtered_true_boxes_list
178
  pred_boxes = filtered_pred_boxes
179
+
180
 
181
  #--------------------------------------------------------------------------------------------
182
  # YOUR MODEL INFERENCE STOPS HERE