# voj/app.py
import json
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import os
import requests
from config import Config
from model import BirdAST
import torch
import librosa
import librosa.display  # explicit import so librosa.display is available on older librosa versions
import noisereduce as nr
import timm
from typing import Iterable
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
import time
import pandas as pd
from classpred import predict_class
import torch.nn.functional as F
import random
from torchaudio.compliance import kaldi
from torchaudio.functional import resample
from transformers import ASTFeatureExtractor
#TAG = "gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k"
#MODEL = timm.create_model(f"hf_hub:{TAG}", pretrained=True).eval()
#LABEL_URL = "https://huggingface.co/datasets/huggingface/label-files/raw/main/audioset-id2label.json"
#AUDIOSET_LABELS = list(json.loads(requests.get(LABEL_URL).content).values())
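# ASTFeatureExtractor converts raw waveforms into the spectrogram inputs expected by the BirdAST models
# (its defaults assume 16 kHz audio)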
FEATURE_EXTRACTOR = ASTFeatureExtractor()
def plot_mel(sr, x):
    # Mel spectrogram (dB scale, normalized to [0, 1]) of the selected audio
    mel_spec = librosa.feature.melspectrogram(y=x, sr=sr, n_mels=128, fmax=10000)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
    mel_spec_db = (mel_spec_db - mel_spec_db.min()) / (mel_spec_db.max() - mel_spec_db.min())  # normalize to [0, 1]
    fig, ax = plt.subplots(nrows=1, ncols=1)
    librosa.display.specshow(mel_spec_db, sr=sr, x_axis='time', y_axis='mel', fmin=0, fmax=10000, ax=ax)
    return fig
def plot_wave(sr, x):
ry = nr.reduce_noise(y=x, sr=sr)
fig, ax = plt.subplots(2, 1, figsize=(12, 8))
# Plot the original waveform
librosa.display.waveshow(x, sr=sr, ax=ax[0])
ax[0].set(title='Original Waveform')
ax[0].set_xlabel('Time (s)')
ax[0].set_ylabel('Amplitude')
# Plot the noise-reduced waveform
librosa.display.waveshow(ry, sr=sr, ax=ax[1])
ax[1].set(title='Noise Reduced Waveform')
ax[1].set_xlabel('Time (s)')
ax[1].set_ylabel('Amplitude')
plt.tight_layout()
return fig
def predict(audio, start, end):
    sr, x = audio
    x = np.array(x, dtype=np.float32) / 32768.0  # int16 PCM -> float in [-1, 1]

    # Validate the requested window before slicing the audio
    if start >= end:
        raise gr.Error(f"`start` ({start}s) must be smaller than `end` ({end}s)")
    if x.shape[0] < start * sr:
        raise gr.Error(f"`start` ({start}s) must be smaller than the audio duration ({x.shape[0] / sr:.0f}s)")
    if x.shape[0] < end * sr:
        end = x.shape[0] / sr  # clamp `end` to the audio duration

    x = x[int(start * sr): int(end * sr)]
    res = preprocess_for_inference(x, sr)
    fig1 = plot_mel(sr, x)
    fig2 = plot_wave(sr, x)
    return predict_class(x, sr, start, end), res, fig1, fig2
def download_model(url, model_path):
if not os.path.exists(model_path):
response = requests.get(url)
response.raise_for_status() # Ensure the request was successful
with open(model_path, 'wb') as f:
f.write(response.content)
# Model URL and path
model_urls = [f'https://huggingface.co/shiyi-li/BirdAST/resolve/main/BirdAST_Baseline_GroupKFold_fold_{i}.pth' for i in range(5)]
model_paths = [f'BirdAST_Baseline_GroupKFold_fold_{i}.pth' for i in range(5)]
for (model_url, model_path) in zip(model_urls, model_paths):
download_model(model_url, model_path)
# Instantiate one BirdAST model per fold (architecture defined in model.py) and load its weights
eval_models = [BirdAST(Config().backbone_name, Config().n_classes, n_mlp_layers=1, activation='silu') for _ in range(5)]
state_dicts = [torch.load(f'BirdAST_Baseline_GroupKFold_fold_{i}.pth', map_location='cpu') for i in range(5)]
for idx, sd in enumerate(state_dicts):
    eval_models[idx].load_state_dict(sd)

# Set every fold model to evaluation mode
for m in eval_models:
    m.eval()
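# The app description below notes that BirdAST adds an attention-pooling layer and a dense
# classifier on top of the AST backbone. The actual architecture is defined in model.py; the
# class below is only an illustrative sketch of attention pooling over token embeddings
# (hidden_size and n_classes are placeholders) and is not used anywhere in this app.
import torch.nn as nn

class AttentionPoolingHeadSketch(nn.Module):
    """Toy attention-pooling classification head: weight tokens, sum, classify."""

    def __init__(self, hidden_size: int, n_classes: int):
        super().__init__()
        self.attn = nn.Linear(hidden_size, 1)                # one attention score per token
        self.classifier = nn.Linear(hidden_size, n_classes)  # dense layer on the pooled embedding

    def forward(self, token_embeddings: torch.Tensor) -> torch.Tensor:
        # token_embeddings: (batch, seq_len, hidden_size) from the backbone
        weights = torch.softmax(self.attn(token_embeddings), dim=1)  # (batch, seq_len, 1)
        pooled = (weights * token_embeddings).sum(dim=1)             # (batch, hidden_size)
        return self.classifier(pooled)                               # (batch, n_classes)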
# Load the species mapping
label_mapping = pd.read_csv('BirdAST_Baseline_GroupKFold_label_map.csv')
species_id_to_name = {row['species_id']: row['scientific_name'] for index, row in label_mapping.iterrows()}
def preprocess_for_inference(audio_arr, sr):
    # The AST feature extractor is configured for 16 kHz audio; resample if the input rate differs
    if sr != 16000:
        audio_arr = librosa.resample(audio_arr, orig_sr=sr, target_sr=16000)
        sr = 16000

    spec = FEATURE_EXTRACTOR(audio_arr, sampling_rate=sr, padding="max_length", return_tensors="pt")
    input_values = spec['input_values']  # spectrogram features prepared for model input

    # Accumulate softmax predictions from each fold's model
    model_outputs = []
    with torch.no_grad():
        for fold_model in eval_models:
            output = fold_model(input_values)
            predict_score = F.softmax(output['logits'], dim=1)
            model_outputs.append(predict_score)

    # Average the predictions across all models
    avg_predictions = torch.mean(torch.cat(model_outputs), dim=0)

    # Get the top 10 predictions based on the average prediction scores
    topk_values, topk_indices = torch.topk(avg_predictions, 10)

    # Map class indices to species names and report probabilities in percent
    results = []
    for idx, score in zip(topk_indices, topk_values):
        species_name = species_id_to_name[idx.item()]
        probability = score.item() * 100
        results.append([species_name, probability])
    return results
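# Example usage (not executed by the app; "example.wav" is a placeholder path):
#   y, file_sr = librosa.load("example.wav", sr=None, mono=True)
#   top10 = preprocess_for_inference(y, file_sr)   # list of [scientific_name, probability %]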
DESCRIPTION = """
# Introduction
It is estimated that 50% of the global economy is threatened by biodiversity loss [2]. Intensive efforts have therefore gone into estimating bird biodiversity, since birds are a top indicator of the biodiversity of a region. One such effort is identifying the bird species present in a region through audio classification.
# Solution
To tackle this problem, we propose VOJ. It first preprocesses the audio signal with a bandpass filter (1 kHz to 8 kHz) and then downsamples it to 16 kHz. The signal is then passed to AudioMAE (Audio Masked Autoencoder by Meta [1]), which extracts relevant features even when the signal's spectrogram is corrupted.
AudioMAE is trained on 527 audio classes covering birds, silence, environmental noise, and other categories. The purpose of this initial inference stage is to give a first impression of the audio: if AudioMAE outputs silence, we can expect low species-prediction confidence, and if it outputs insect, the clip may not be worth labelling.
Next, we train BirdAST, which uses the Audio Spectrogram Transformer (AST) as its backbone, followed by an attention-pooling layer and a dense layer. We also train an EfficientNet-B0 on the mel spectrogram, and finally we train a model based on Wav2Vec pretrained on 50 bird species [3].
"""
css = """
#gradio-animation {
font-size: 2em;
font-weight: bold;
text-align: center;
margin-bottom: 20px;
}
.logo-container img {
width: 14%; /* Adjust width as necessary */
display: block;
margin: auto;
}
.number-input {
height: 100%;
    padding-bottom: 60px; /* Adjust the value as needed for more or less space */
}
.full-height {
height: 100%;
}
.column-container {
height: 100%;
}
"""
class Seafoam(Base):
def __init__(
self,
*,
primary_hue: colors.Color | str = colors.emerald,
secondary_hue: colors.Color | str = colors.blue,
neutral_hue: colors.Color | str = colors.gray,
spacing_size: sizes.Size | str = sizes.spacing_md,
radius_size: sizes.Size | str = sizes.radius_md,
text_size: sizes.Size | str = sizes.text_lg,
font: fonts.Font
| str
| Iterable[fonts.Font | str] = (
fonts.GoogleFont("Quicksand"),
"ui-sans-serif",
"sans-serif",
),
font_mono: fonts.Font
| str
| Iterable[fonts.Font | str] = (
fonts.GoogleFont("IBM Plex Mono"),
"ui-monospace",
"monospace",
),
):
super().__init__(
primary_hue=primary_hue,
secondary_hue=secondary_hue,
neutral_hue=neutral_hue,
spacing_size=spacing_size,
radius_size=radius_size,
text_size=text_size,
font=font,
font_mono=font_mono,
)
seafoam = Seafoam()
js = """
function createGradioAnimation() {
var container = document.getElementById('gradio-animation');
var text = 'Voice of Jungle';
for (var i = 0; i < text.length; i++) {
(function(i){
setTimeout(function(){
var letter = document.createElement('span');
letter.style.opacity = '0';
letter.style.transition = 'opacity 0.5s';
letter.innerText = text[i];
container.appendChild(letter);
setTimeout(function() {
letter.style.opacity = '1';
}, 50);
}, i * 250);
})(i);
}
}
"""
REFERENCES = """
# References
[1] Huang, P.-Y., Xu, H., Li, J., Baevski, A., Auli, M., Galuba, W., Metze, F., & Feichtenhofer, C. (2022). Masked Autoencoders that Listen. In NeurIPS.
[2] Torkington, S. (2023, February 7). 50% of the global economy is under threat from biodiversity loss. World Economic Forum. Retrieved from https://www.weforum.org/agenda/2023/02/biodiversity-nature-loss-cop15/.
[3] dima806. Bird species by sound detection [Kaggle notebook]. Retrieved from https://www.kaggle.com/code/dima806/bird-species-by-sound-detection
"""
with gr.Blocks(theme = seafoam, css = css, js = js) as demo:
gr.Markdown('<div class="logo-container"><img src="https://i.ibb.co/vcG9kr0/vojlogo.jpg" width="50px" alt="vojlogo"></div>')
gr.Markdown('<div id="gradio-animation"></div>')
gr.Markdown(DESCRIPTION)
with gr.Row():
with gr.Column(elem_classes="column-container"):
            start_time_input = gr.Number(label="Start Time [s]", value=0, elem_classes="number-input full-height")
            end_time_input = gr.Number(label="End Time [s]", value=10, elem_classes="number-input full-height")
with gr.Column():
audio_input = gr.Audio(label="Input Audio", elem_classes="full-height")
with gr.Row():
raw_class_output = gr.Dataframe(headers=["Class", "Score [%]"], row_count=10, label="Class Prediction")
species_output = gr.Dataframe(headers=["Class", "Score [%]"], row_count=10, label="Species Prediction")
with gr.Row():
waveform_output = gr.Plot(label="Waveform")
spectrogram_output = gr.Plot(label="Spectrogram")
gr.Examples(
examples=[
["1094_Pionus_fuscus_2.wav", 0, 10],
],
inputs=[audio_input, start_time_input, end_time_input]
)
gr.Button("Predict").click(predict, [audio_input, start_time_input, end_time_input], [raw_class_output, species_output, waveform_output, spectrogram_output])
gr.Markdown(REFERENCES)
demo.launch(share = True)