File size: 3,453 Bytes
6327a1b
 
 
 
 
 
 
4d6e8c2
fe4a4cb
 
 
4d6e8c2
fe4a4cb
4768d6b
6327a1b
3b09640
 
4d6e8c2
70f5f26
1c33274
70f5f26
6327a1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d6e8c2
6327a1b
 
adf98b6
398ac10
adf98b6
fe4a4cb
 
 
adf98b6
 
 
 
 
 
b321cd2
6327a1b
adf98b6
 
b321cd2
6327a1b
adf98b6
6327a1b
b321cd2
6327a1b
 
 
b321cd2
6327a1b
 
b321cd2
6327a1b
 
adf98b6
6327a1b
 
b321cd2
adf98b6
6327a1b
adf98b6
 
 
 
 
 
6327a1b
adf98b6
 
 
 
 
 
 
 
 
 
 
 
 
fe4a4cb
2a198b3
6327a1b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import librosa
import os
import tarfile
from tensorflow.keras.models import load_model
from fastapi import APIRouter
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from .utils.evaluation import AudioEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info
from dotenv import load_dotenv 

load_dotenv()

router = APIRouter()
# NOTE(review): this route runs YAMNet + a trained classifier, not a random
# baseline. The string is kept because it is reported as "model_description"
# in the results; update it deliberately if that is intended.
DESCRIPTION = "Random Baseline"
ROUTE = "/audio"

# Paths to the local model artefacts. They are relative, so they resolve
# against the process's current working directory, not this module's folder.
YAMNET_TAR_PATH = "./yamnet-tensorflow2-yamnet-v1.tar.gz"  # YAMNet archive (gzip tar)
EXTRACT_PATH = "./yamnet_model"  # unpacked YAMNet directory passed to hub.load
CLASSIFIER_PATH = "./audio_model.h5"  # trained classifier loaded with load_model

# Extract YAMNet once; later runs reuse the unpacked directory.
# NOTE: only the directory's existence is checked, so an interrupted
# extraction leaves a partial directory that will not be repaired
# automatically -- delete EXTRACT_PATH by hand in that case.
if not os.path.exists(EXTRACT_PATH):
    with tarfile.open(YAMNET_TAR_PATH, "r:gz") as tar:
        # The "data" extraction filter rejects absolute paths, ".." traversal,
        # links escaping the destination and special files. Calling
        # extractall() without a filter emits a DeprecationWarning on Python
        # 3.12/3.13 ("data" becomes the default in 3.14). Interpreters that
        # predate extraction filters fall back to the unfiltered call.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(EXTRACT_PATH, filter="data")
        else:
            tar.extractall(EXTRACT_PATH)

# Load YAMNet from the extracted directory (used to compute embeddings).
yamnet = hub.load(EXTRACT_PATH)

# Load the trained classifier that maps YAMNet embeddings to class scores.
audio_model = load_model(CLASSIFIER_PATH)


def _mean_yamnet_embedding(audio_data):
    """Return the time-averaged YAMNet embedding of one decoded audio sample.

    Args:
        audio_data: one item of ``test_dataset["audio"]``, a mapping with an
            ``"array"`` waveform and its ``"sampling_rate"``.

    Returns:
        A NumPy vector: the YAMNet embeddings averaged over axis 0 (time).
    """
    waveform = audio_data["array"]
    sample_rate = audio_data["sampling_rate"]

    # Waveforms are brought to 16 kHz before being passed to YAMNet.
    if sample_rate != 16000:
        waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)

    # tf.squeeze drops size-1 axes so YAMNet receives a 1-D waveform.
    # NOTE(review): a multi-channel array would stay 2-D here; this assumes the
    # dataset decodes audio to mono -- confirm.
    waveform = tf.squeeze(tf.convert_to_tensor(waveform, dtype=tf.float32))

    # yamnet returns a 3-tuple; only the middle item (embeddings) is used,
    # averaged over axis 0 (time).
    _, embeddings, _ = yamnet(waveform)
    return tf.reduce_mean(embeddings, axis=0).numpy()


def _predict_numeric_labels(embedding_rows):
    """Classify a list of embedding vectors and return numeric labels.

    Class index 0 ("chainsaw") maps to 0; every other index
    ("environment") maps to 1.
    Returns an empty list when ``embedding_rows`` is empty.
    """
    if not embedding_rows:
        return []

    # A single batched predict() call instead of one call per sample:
    # each predict() call carries fixed setup overhead, which is pure waste
    # inside an energy-tracked loop.
    scores = audio_model.predict(np.stack(embedding_rows), verbose=0)
    # NOTE(review): assumes the classifier emits one score per class; a single
    # sigmoid output unit would always argmax to 0 -- confirm model head.
    class_indices = np.argmax(scores, axis=1)
    return [0 if index == 0 else 1 for index in class_indices]


@router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
async def evaluate_audio(request: AudioEvaluationRequest):
    """Classify every clip in the dataset's "test" split and report metrics.

    Loads ``request.dataset_name`` (using the ``HF_TOKEN`` environment
    variable), embeds each clip with YAMNet, classifies the embeddings with
    the trained classifier, and scores the predictions against the split's
    ``"label"`` column. The "inference" emissions-tracking task spans
    embedding, classification and scoring.

    Returns:
        A dict with the submission timestamp, model description, accuracy,
        energy/emissions figures (values from the tracker scaled by 1000),
        cleaned emissions data, the API route, and the request's dataset
        configuration.

    Raises:
        Any exception from dataset loading, feature extraction, prediction or
        scoring propagates. If it occurs after tracking starts, the
        "inference" task is stopped first.
    """
    dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))
    test_dataset = dataset["test"]

    tracker.start()
    tracker.start_task("inference")
    try:
        embedding_rows = [
            _mean_yamnet_embedding(audio_data) for audio_data in test_dataset["audio"]
        ]
        numeric_predictions = _predict_numeric_labels(embedding_rows)

        # NOTE(review): assumes dataset label 0 means "chainsaw", matching
        # class index 0 of the classifier -- confirm against the dataset card.
        true_labels = test_dataset["label"]
        accuracy = accuracy_score(true_labels, numeric_predictions)
    except Exception:
        # Close the task so a failed request does not leave "inference"
        # running for the next request, then propagate the original error.
        tracker.stop_task()
        raise

    emissions_data = tracker.stop_task()

    return {
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed,
        },
    }