satvs commited on
Commit
fc60ded
·
1 Parent(s): cd8c2b8

Optimize submission

Browse files
Files changed (1) hide show
  1. tasks/image.py +20 -28
tasks/image.py CHANGED
@@ -18,7 +18,6 @@ from pathlib import Path
18
  from ultralytics import YOLO
19
  from torch import device
20
  from torch.cuda import is_available
21
- from torch import no_grad
22
 
23
  router = APIRouter()
24
 
@@ -120,38 +119,31 @@ async def evaluate_image(request: ImageEvaluationRequest):
120
  true_labels = []
121
  pred_boxes = []
122
  true_boxes_list = [] # List of lists, each inner list contains boxes for one image
 
 
 
 
 
 
 
123
 
124
- # Preprocess annotations before the loop
125
- preprocessed_annotations = [parse_boxes(example.get("annotations", "").strip()) for example in test_dataset]
126
-
127
- # Use torch.no_grad() to disable gradient tracking during inference
128
- with no_grad():
129
- predictions = []
130
- true_labels = []
131
- pred_boxes = []
132
- true_boxes_list = [] # List of lists, each inner list contains boxes for one image
133
 
134
- logging.info(f"Inference start on device: {device_name}")
135
- for idx, example in enumerate(test_dataset):
136
- annotation = preprocessed_annotations[idx]
137
- has_smoke = len(annotation) > 0
138
- true_labels.append(int(has_smoke))
139
-
140
- # Make prediction for the current image
141
- results = model.predict(example["image"], device=device_name, conf=THRESHOLD, verbose=False, imgsz=IMGSIZE)[0]
142
-
143
- pred_has_smoke = len(results) > 0
144
- predictions.append(int(pred_has_smoke))
145
-
146
- # If there's a true box, add it to the list
147
- if has_smoke:
148
- true_boxes_list.append(annotation) # True boxes are already preprocessed
149
 
150
- # Handle prediction boxes: Append first box (or default box if none detected)
151
- if results.boxes.cls.numel() != 0:
 
152
  pred_boxes.append(results.boxes[0].xywhn.tolist()[0])
153
  else:
154
- pred_boxes.append([0, 0, 0, 0])
155
 
156
  #--------------------------------------------------------------------------------------------
157
  # YOUR MODEL INFERENCE STOPS HERE
 
18
  from ultralytics import YOLO
19
  from torch import device
20
  from torch.cuda import is_available
 
21
 
22
  router = APIRouter()
23
 
 
119
  true_labels = []
120
  pred_boxes = []
121
  true_boxes_list = [] # List of lists, each inner list contains boxes for one image
122
+
123
+ logging.info(f"Inference start on device: {device_name}")
124
+ for example in test_dataset:
125
+ # Parse true annotation (YOLO format: class_id x_center y_center width height)
126
+ annotation = example.get("annotations", "").strip()
127
+ has_smoke = len(annotation) > 0
128
+ true_labels.append(int(has_smoke))
129
 
130
+ # Make prediction
131
+ results = model.predict(example["image"], device=device_name, conf=THRESHOLD, verbose=False, imgsz=IMGSIZE)[0]
132
+ pred_has_smoke = len(results) > 0
133
+ predictions.append(int(pred_has_smoke))
 
 
 
 
 
134
 
135
+ # If there's a true box, parse it and add box prediction
136
+ if has_smoke:
137
+ # Parse all true boxes from the annotation
138
+ image_true_boxes = parse_boxes(annotation)
139
+ true_boxes_list.append(image_true_boxes)
 
 
 
 
 
 
 
 
 
 
140
 
141
+ # Append only one bounding box if at least one fire is detected
142
+ # Note that multiple boxes could be appended
143
+ if results.boxes.cls.numel()!=0:
144
  pred_boxes.append(results.boxes[0].xywhn.tolist()[0])
145
  else:
146
+ pred_boxes.append([0,0,0,0])
147
 
148
  #--------------------------------------------------------------------------------------------
149
  # YOUR MODEL INFERENCE STOPS HERE