Spaces:
Sleeping
Sleeping
Update tasks/image.py
Browse files- tasks/image.py +4 -6
tasks/image.py
CHANGED
@@ -163,7 +163,7 @@ def compute_max_iou(true_boxes, pred_box):
|
|
163 |
|
164 |
@router.post(ROUTE, tags=["Image Task"],
|
165 |
description=DESCRIPTION)
|
166 |
-
async def evaluate_image(model_path: str = "
|
167 |
"""
|
168 |
Evaluate image classification and object detection for forest fire smoke.
|
169 |
|
@@ -184,10 +184,8 @@ async def evaluate_image(model_path: str = "models/yolo11s_best.pt", request: Im
|
|
184 |
# Split dataset
|
185 |
train_test = dataset["train"]
|
186 |
test_dataset = dataset["val"]
|
187 |
-
|
188 |
-
|
189 |
-
if("detr" in model_path):
|
190 |
-
model = RTDETR(model_path)
|
191 |
|
192 |
# Start tracking emissions
|
193 |
tracker.start()
|
@@ -203,7 +201,7 @@ async def evaluate_image(model_path: str = "models/yolo11s_best.pt", request: Im
|
|
203 |
pred_boxes = []
|
204 |
true_boxes_list = [] # List of lists, each inner list contains boxes for one image
|
205 |
|
206 |
-
for example in
|
207 |
# Parse true annotation (YOLO format: class_id x_center y_center width height)
|
208 |
annotation = example.get("annotations", "").strip()
|
209 |
has_smoke = len(annotation) > 0
|
|
|
163 |
|
164 |
@router.post(ROUTE, tags=["Image Task"],
|
165 |
description=DESCRIPTION)
|
166 |
+
async def evaluate_image(model_path: str = "models_v3/yolo11s_best.engine", request: ImageEvaluationRequest = ImageEvaluationRequest()):
|
167 |
"""
|
168 |
Evaluate image classification and object detection for forest fire smoke.
|
169 |
|
|
|
184 |
# Split dataset
|
185 |
train_test = dataset["train"]
|
186 |
test_dataset = dataset["val"]
|
187 |
+
|
188 |
+
model = YOLO(model_path, task="detect")
|
|
|
|
|
189 |
|
190 |
# Start tracking emissions
|
191 |
tracker.start()
|
|
|
201 |
pred_boxes = []
|
202 |
true_boxes_list = [] # List of lists, each inner list contains boxes for one image
|
203 |
|
204 |
+
for example in test_dataset:
|
205 |
# Parse true annotation (YOLO format: class_id x_center y_center width height)
|
206 |
annotation = example.get("annotations", "").strip()
|
207 |
has_smoke = len(annotation) > 0
|