Update tasks/audio.py
Browse files- tasks/audio.py +30 -9
tasks/audio.py
CHANGED
@@ -239,18 +239,39 @@ for audio_data in test_dataset["audio"]:
|
|
239 |
|
240 |
print("Predictions:", predictions)
|
241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
242 |
|
243 |
-
|
244 |
#--------------------------------------------------------------------------------------------
|
245 |
# YOUR MODEL INFERENCE STOPS HERE
|
246 |
-
#--------------------------------------------------------------------------------------------
|
247 |
-
|
248 |
# Stop tracking emissions
|
249 |
emissions_data = tracker.stop_task()
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
# Prepare results dictionary
|
255 |
results = {
|
256 |
"username": username,
|
@@ -268,5 +289,5 @@ print("Predictions:", predictions)
|
|
268 |
"test_seed": request.test_seed
|
269 |
}
|
270 |
}
|
271 |
-
|
272 |
-
return results
|
|
|
239 |
|
240 |
print("Predictions:", predictions)
|
241 |
|
242 |
+
def map_predictions_to_labels(predictions):
|
243 |
+
"""
|
244 |
+
Maps string predictions to numeric labels:
|
245 |
+
- "chainsaw" -> 0
|
246 |
+
- any other class -> 1
|
247 |
+
Args:
|
248 |
+
predictions (list of str): List of class name predictions.
|
249 |
+
Returns:
|
250 |
+
list of int: Mapped numeric labels.
|
251 |
+
"""
|
252 |
+
return [0 if pred == "chainsaw" else 1 for pred in predictions]
|
253 |
+
|
254 |
+
from sklearn.metrics import accuracy_score
|
255 |
+
|
256 |
+
# Map string predictions to numeric labels
|
257 |
+
numeric_predictions = map_predictions_to_labels(predictions)
|
258 |
+
|
259 |
+
# Extract true labels (already numeric)
|
260 |
+
true_labels = test_dataset["label"]
|
261 |
+
|
262 |
+
# Calculate accuracy
|
263 |
+
accuracy = accuracy_score(true_labels, numeric_predictions)
|
264 |
+
print("Accuracy:", accuracy)
|
265 |
|
|
|
266 |
#--------------------------------------------------------------------------------------------
|
267 |
# YOUR MODEL INFERENCE STOPS HERE
|
268 |
+
#--------------------------------------------------------------------------------------------
|
269 |
+
|
270 |
# Stop tracking emissions
|
271 |
emissions_data = tracker.stop_task()
|
272 |
+
|
273 |
+
|
274 |
+
|
|
|
275 |
# Prepare results dictionary
|
276 |
results = {
|
277 |
"username": username,
|
|
|
289 |
"test_seed": request.test_seed
|
290 |
}
|
291 |
}
|
292 |
+
|
293 |
+
return results
|