ImenMourali commited on
Commit
f19a289
·
verified ·
1 Parent(s): 8169d28

Update tasks/audio.py

Browse files
Files changed (1) hide show
  1. tasks/audio.py +30 -9
tasks/audio.py CHANGED
@@ -239,18 +239,39 @@ for audio_data in test_dataset["audio"]:
239
 
240
  print("Predictions:", predictions)
241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
-
244
  #--------------------------------------------------------------------------------------------
245
  # YOUR MODEL INFERENCE STOPS HERE
246
- #--------------------------------------------------------------------------------------------
247
-
248
  # Stop tracking emissions
249
  emissions_data = tracker.stop_task()
250
-
251
- # Calculate accuracy
252
- accuracy = accuracy_score(true_labels, predictions)
253
-
254
  # Prepare results dictionary
255
  results = {
256
  "username": username,
@@ -268,5 +289,5 @@ print("Predictions:", predictions)
268
  "test_seed": request.test_seed
269
  }
270
  }
271
-
272
- return results
 
239
 
240
  print("Predictions:", predictions)
241
 
242
+ def map_predictions_to_labels(predictions):
243
+ """
244
+ Maps string predictions to numeric labels:
245
+ - "chainsaw" -> 0
246
+ - any other class -> 1
247
+ Args:
248
+ predictions (list of str): List of class name predictions.
249
+ Returns:
250
+ list of int: Mapped numeric labels.
251
+ """
252
+ return [0 if pred == "chainsaw" else 1 for pred in predictions]
253
+
254
+ from sklearn.metrics import accuracy_score
255
+
256
+ # Map string predictions to numeric labels
257
+ numeric_predictions = map_predictions_to_labels(predictions)
258
+
259
+ # Extract true labels (already numeric)
260
+ true_labels = test_dataset["label"]
261
+
262
+ # Calculate accuracy
263
+ accuracy = accuracy_score(true_labels, numeric_predictions)
264
+ print("Accuracy:", accuracy)
265
 
 
266
  #--------------------------------------------------------------------------------------------
267
  # YOUR MODEL INFERENCE STOPS HERE
268
+ #--------------------------------------------------------------------------------------------
269
+
270
  # Stop tracking emissions
271
  emissions_data = tracker.stop_task()
272
+
273
+
274
+
 
275
  # Prepare results dictionary
276
  results = {
277
  "username": username,
 
289
  "test_seed": request.test_seed
290
  }
291
  }
292
+
293
+ return results