# ImenMourali's picture
# Update tasks/audio.py
# 6327a1b verified
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import librosa
import os
import tarfile
from tensorflow.keras.models import load_model
from fastapi import APIRouter
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from .utils.evaluation import AudioEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info
from dotenv import load_dotenv
# Read variables from a local .env file into the process environment.
load_dotenv()

router = APIRouter()

DESCRIPTION = "Random Baseline"
ROUTE = "/audio"

# Locations of the local model artefacts (relative to the working directory).
YAMNET_TAR_PATH = "./yamnet-tensorflow2-yamnet-v1.tar.gz"  # archive must exist at this path
EXTRACT_PATH = "./yamnet_model"
CLASSIFIER_PATH = "./audio_model.h5"


def _unpack_yamnet_archive():
    """Extract the YAMNet archive into EXTRACT_PATH unless that path already exists."""
    if os.path.exists(EXTRACT_PATH):
        return
    with tarfile.open(YAMNET_TAR_PATH, "r:gz") as archive:
        archive.extractall(EXTRACT_PATH)


_unpack_yamnet_archive()

# YAMNet model, loaded from the extracted directory.
yamnet = hub.load(EXTRACT_PATH)

# Trained classifier that consumes the YAMNet embeddings.
audio_model = load_model(CLASSIFIER_PATH)
# Sample rate every waveform is resampled to before being passed to YAMNet.
_YAMNET_SAMPLE_RATE = 16000


def _extract_embedding(audio_data):
    """Return the time-averaged YAMNet embedding of one decoded audio sample.

    Args:
        audio_data: Mapping with an ``"array"`` waveform and its
            ``"sampling_rate"``, as yielded by the dataset's ``"audio"`` column.

    Returns:
        A flat (1-D) NumPy array: the YAMNet embeddings averaged over axis 0.
    """
    waveform = audio_data["array"]
    sample_rate = audio_data["sampling_rate"]

    # Resample only when the clip is not already at the target rate.
    if sample_rate != _YAMNET_SAMPLE_RATE:
        waveform = librosa.resample(
            waveform, orig_sr=sample_rate, target_sr=_YAMNET_SAMPLE_RATE
        )

    # Squeeze drops singleton dimensions so a (1, n) or (n, 1) clip becomes 1-D.
    # NOTE(review): genuinely multi-channel audio is not down-mixed here.
    waveform = tf.squeeze(tf.convert_to_tensor(waveform, dtype=tf.float32))

    # YAMNet returns three outputs; only the second (embeddings) is used.
    _, frame_embeddings, _ = yamnet(waveform)

    # Average over the first axis and flatten to one feature vector per clip.
    return tf.reduce_mean(frame_embeddings, axis=0).numpy().reshape(-1)


@router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
async def evaluate_audio(request: AudioEvaluationRequest):
    """Classify the test split of an audio dataset and report accuracy and emissions.

    The ``"test"`` split of ``request.dataset_name`` is loaded, every clip is
    embedded with YAMNet, and the trained classifier predicts one class per
    clip. Class index 0 is scored as label 0 ("chainsaw"); any other index is
    scored as label 1 ("environment").

    Args:
        request: Evaluation request. Only ``dataset_name`` drives the
            evaluation; ``test_size`` and ``test_seed`` are echoed back in the
            result but are not used to split the data.

    Returns:
        A dict with the submission timestamp, model description, accuracy,
        energy/emissions figures for the tracked "inference" task, the API
        route and the dataset configuration.
    """
    # Load dataset
    dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))
    test_dataset = dataset["test"]

    # Start tracking emissions
    tracker.start()
    tracker.start_task("inference")

    embeddings = [
        _extract_embedding(audio_data) for audio_data in test_dataset["audio"]
    ]

    if embeddings:
        # Run the classifier once over all clips: Keras' predict() is a batch
        # API and is not meant to be called per sample inside a loop.
        scores = np.asarray(audio_model.predict(np.stack(embeddings)))
        # Flatten each sample's scores before argmax, mirroring a per-sample
        # np.argmax over the whole score array.
        class_indices = scores.reshape(len(embeddings), -1).argmax(axis=1)
        numeric_predictions = [0 if index == 0 else 1 for index in class_indices]
    else:
        # Nothing to classify; avoid stacking an empty list.
        numeric_predictions = []

    true_labels = test_dataset["label"]
    accuracy = accuracy_score(true_labels, numeric_predictions)

    # Stop tracking emissions
    emissions_data = tracker.stop_task()

    # Prepare results
    results = {
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed,
        },
    }
    return results