|
import json |
|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import os |
|
import requests |
|
from config import Config |
|
from model import BirdAST |
|
import torch |
|
import librosa |
|
import noisereduce as nr |
|
import timm |
|
from typing import Iterable |
|
import gradio as gr |
|
from gradio.themes.base import Base |
|
from gradio.themes.utils import colors, fonts, sizes |
|
import time |
|
import pandas as pd |
|
from classpred import predict_class |
|
import torch.nn.functional as F |
|
import random |
|
from torchaudio.compliance import kaldi |
|
from torchaudio.functional import resample |
|
from transformers import ASTFeatureExtractor |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Shared Audio Spectrogram Transformer feature extractor, built once at import
# time with its default configuration and reused by preprocess_for_inference.
# NOTE(review): transformers' sequence feature extractors compare the
# `sampling_rate` passed at call time with their configured rate — confirm that
# callers feed audio at `FEATURE_EXTRACTOR.sampling_rate`.
FEATURE_EXTRACTOR = ASTFeatureExtractor()
|
|
|
def plot_mel(sr, x, n_mels=128, fmax=10000):
    """Plot a min-max normalised mel spectrogram of ``x``.

    Args:
        sr: Sampling rate of ``x`` in Hz.
        x: Waveform to plot (NOTE(review): presumably a 1-D float array;
            ``predict`` passes a mono float32 clip).
        n_mels: Number of mel bands.
        fmax: Highest frequency shown, in Hz. Capped at the Nyquist frequency
            ``sr / 2`` because mel filters above it would be empty.

    Returns:
        A ``matplotlib.figure.Figure`` with a single spectrogram axis.
    """
    # `librosa.display` is a submodule that `import librosa` does not reliably load.
    import librosa.display
    # Build the Figure directly instead of via pyplot. Pyplot keeps every figure
    # it creates alive until it is closed (a leak in a long-running server), and
    # its global "current figure" is shared between request threads.
    from matplotlib.figure import Figure

    fmax = min(fmax, sr / 2)
    mel_spec = librosa.feature.melspectrogram(y=x, sr=sr, n_mels=n_mels, fmax=fmax)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

    # Scale to [0, 1]. A constant spectrogram (e.g. digital silence) has zero
    # range, and dividing by it would turn the whole image into NaN.
    value_range = mel_spec_db.max() - mel_spec_db.min()
    if value_range > 0:
        mel_spec_db = (mel_spec_db - mel_spec_db.min()) / value_range
    else:
        mel_spec_db = np.zeros_like(mel_spec_db)

    fig = Figure()
    ax = fig.subplots()
    librosa.display.specshow(mel_spec_db, sr=sr, x_axis='time', y_axis='mel', fmin=0, fmax=fmax, ax=ax)
    return fig
|
|
|
def plot_wave(sr, x):
    """Plot ``x`` above a noise-reduced copy of it.

    Args:
        sr: Sampling rate of ``x`` in Hz.
        x: Waveform to plot (NOTE(review): presumably a 1-D float array;
            ``predict`` passes a mono float32 clip).

    Returns:
        A ``matplotlib.figure.Figure`` with two stacked axes: the original
        waveform (top) and the output of ``noisereduce.reduce_noise`` (bottom).
    """
    # `librosa.display` is a submodule that `import librosa` does not reliably load.
    import librosa.display
    # Build the Figure directly instead of via pyplot, so figures are not kept
    # alive by pyplot after each request and no global "current figure" is used.
    from matplotlib.figure import Figure

    reduced = nr.reduce_noise(y=x, sr=sr)

    fig = Figure(figsize=(12, 8))
    axes = fig.subplots(2, 1)
    panels = ((x, 'Original Waveform'), (reduced, 'Noise Reduced Waveform'))
    for ax, (signal, title) in zip(axes, panels):
        librosa.display.waveshow(signal, sr=sr, ax=ax)
        ax.set(title=title)
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Amplitude')

    # Lay out *this* figure; plt.tight_layout() would act on the current pyplot figure.
    fig.tight_layout()
    return fig
|
|
|
def predict(audio, start, end):
    """Gradio callback: classify the ``[start, end)`` segment of the input audio.

    Args:
        audio: ``(sample_rate, samples)`` tuple from ``gr.Audio`` (numpy output),
            or ``None`` when nothing was provided. Signed-integer samples are
            treated as PCM and scaled to [-1, 1). Multi-channel audio is
            averaged to mono; it is assumed to be laid out as
            ``(samples, channels)``.
        start: Segment start in seconds. Must be >= 0 and inside the audio.
        end: Segment end in seconds. Must be > ``start``; clamped to the
            audio duration.

    Returns:
        ``(class_rows, species_rows, waveform_fig, spectrogram_fig)``, in the
        same order as the UI outputs: class table, species table,
        "Waveform" plot, "Spectrogram" plot.

    Raises:
        gr.Error: if no audio was given or the time range is invalid or empty.
    """
    if audio is None:
        raise gr.Error("Please provide an input audio clip.")
    sr, x = audio

    x = np.asarray(x)
    if np.issubdtype(x.dtype, np.signedinteger):
        # PCM integers -> float in [-1, 1); for int16 this is the usual / 32768.
        x = x.astype(np.float32) / float(np.iinfo(x.dtype).max + 1)
    else:
        x = x.astype(np.float32)
    if x.ndim > 1:
        # The feature extractor and the plots expect a 1-D mono signal.
        x = x.mean(axis=1)

    # Validate against the full recording, before slicing it and before
    # spending time on the model ensemble.
    duration = x.shape[0] / sr
    if start < 0:
        raise gr.Error(f"`start` ({start}) must not be negative")
    if start >= end:
        raise gr.Error(f"`start` ({start}) must be smaller than end ({end}s)")
    if start >= duration:
        raise gr.Error(f"`start` ({start}) must be smaller than audio duration ({duration:.0f}s)")
    end = min(end, duration)

    x = x[int(start * sr):int(end * sr)]
    if x.shape[0] == 0:
        raise gr.Error("The selected segment contains no samples")

    species_rows = preprocess_for_inference(x, sr)
    class_rows = predict_class(x, sr, start, end)
    # Order must match the click() outputs: waveform plot, then mel spectrogram.
    return class_rows, species_rows, plot_wave(sr, x), plot_mel(sr, x)
|
|
|
def download_model(url, model_path, timeout=60, chunk_size=1 << 20):
    """Download ``url`` to ``model_path`` unless that file already exists.

    The response body is streamed to ``<model_path>.part`` and renamed into
    place only once it has been fully written. An interrupted download
    therefore never leaves a truncated file that a later run would mistake
    for a complete checkpoint.

    Args:
        url: Address of the file to fetch.
        model_path: Destination path (str or path-like).
        timeout: Seconds passed to ``requests`` as the connect/read timeout,
            so a stalled server cannot hang start-up forever.
        chunk_size: Bytes per streamed chunk, so the file is never held in
            memory all at once.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection errors or timeouts.
    """
    if os.path.exists(model_path):
        return

    tmp_path = os.fspath(model_path) + '.part'
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        os.replace(tmp_path, model_path)  # atomic on the same filesystem
    finally:
        # Only left behind when the download or write failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
|
|
|
|
|
# One BirdAST checkpoint per GroupKFold fold (judging by the file names). All
# of them are downloaded, loaded and ensembled at inference time.
N_FOLDS = 5

model_urls = [
    f'https://huggingface.co/shiyi-li/BirdAST/resolve/main/BirdAST_Baseline_GroupKFold_fold_{i}.pth'
    for i in range(N_FOLDS)
]
model_paths = [f'BirdAST_Baseline_GroupKFold_fold_{i}.pth' for i in range(N_FOLDS)]

for model_url, model_path in zip(model_urls, model_paths):
    download_model(model_url, model_path)

# Build each model, load its weights and switch it to eval mode in one pass.
# The loaded state dict is not kept afterwards: load_state_dict copies it into
# the model's parameters, so retaining it would only double memory use.
_config = Config()
eval_models = []
for model_path in model_paths:
    fold_model = BirdAST(_config.backbone_name, _config.n_classes, n_mlp_layers=1, activation='silu')
    # NOTE(review): torch.load unpickles the checkpoint, which can execute
    # arbitrary code. These files are fetched over the network, so only load
    # them from a trusted source (torch >= 1.13 offers weights_only=True).
    fold_model.load_state_dict(torch.load(model_path, map_location='cpu'))
    fold_model.eval()
    eval_models.append(fold_model)

label_mapping = pd.read_csv('BirdAST_Baseline_GroupKFold_label_map.csv')
# species_id (the index of a model output class) -> scientific name.
species_id_to_name = dict(zip(label_mapping['species_id'], label_mapping['scientific_name']))
|
|
|
def preprocess_for_inference(audio_arr, sr, top_k=10):
    """Run the BirdAST fold ensemble on a waveform and return the top species.

    Despite its name, this performs the whole inference:
    1. feature extraction;
    2. a forward pass through every model in ``eval_models``;
    3. averaging of the per-model softmax scores;
    4. a top-k lookup into ``species_id_to_name``.

    Args:
        audio_arr: Waveform (NOTE(review): presumably 1-D mono float in
            [-1, 1]; ``predict`` passes such an array).
        sr: Sampling rate of ``audio_arr`` in Hz. The audio is resampled to
            ``FEATURE_EXTRACTOR.sampling_rate`` when the two differ.
        top_k: Maximum number of species to return; capped at the number of
            classes.

    Returns:
        List of ``[scientific_name, probability_percent]`` rows, most
        probable first.
    """
    target_sr = FEATURE_EXTRACTOR.sampling_rate
    if sr != target_sr:
        # The AST feature extractor raises ValueError when `sampling_rate`
        # differs from the rate it was configured with.
        waveform = torch.as_tensor(audio_arr, dtype=torch.float32)
        audio_arr = resample(waveform, orig_freq=int(sr), new_freq=int(target_sr)).numpy()
        sr = target_sr

    spec = FEATURE_EXTRACTOR(audio_arr, sampling_rate=sr, padding="max_length", return_tensors="pt")
    input_values = spec['input_values']

    with torch.no_grad():
        # One (batch, n_classes) probability tensor per fold model.
        model_outputs = [F.softmax(model(input_values)['logits'], dim=1) for model in eval_models]

    # With a single clip (batch of 1), concatenation stacks the folds along
    # dim 0, and the mean over it gives one (n_classes,) probability vector.
    avg_predictions = torch.mean(torch.cat(model_outputs), dim=0)

    k = min(top_k, avg_predictions.numel())
    topk_values, topk_indices = torch.topk(avg_predictions, k)

    return [
        [species_id_to_name[idx.item()], score.item() * 100]
        for idx, score in zip(topk_indices, topk_values)
    ]
|
|
|
# Markdown shown at the top of the demo page (rendered with gr.Markdown).
DESCRIPTION = """
# Introduction

It is estimated that 50% of the global economy is threatened by biodiversity loss [2]. As such, intensive efforts have been made to estimate bird biodiversity, as birds are a top indicator of biodiversity in a region. One of these efforts is identifying the bird species in a region through bird-sound classification.

# Solution

To tackle this problem, we propose Voice of Jungle (VOJ). It first preprocesses an audio signal using a band-pass filter (1-8 kHz) and then downsamples it to 16 kHz. Afterwards, we feed the signal into AudioMAE (Audio Masked Autoencoder by Meta [1]), which extracts relevant features even when the signal's spectrogram is corrupted.

The AudioMAE is also trained on 527 types of audio, including bird sounds, silence, environmental noise, and other sounds. The purpose of this initial inference stage is to provide an initial sense of the audio: if the AudioMAE outputs silence, we can expect low species-prediction confidence, and if it outputs insect, the clip may not be worth labelling.

Next, we train BirdAST, which uses an Audio Spectrogram Transformer (AST) as its backbone, followed by attention pooling and a dense layer. We also train EfficientNet-B0 on the mel spectrogram, and finally, we train a model using Wav2Vec pretrained on 50 bird species [3].
"""
|
|
|
|
|
# Page styles passed to gr.Blocks(css=...). The selectors match the ids and
# elem_classes used when the UI is built.
css = """
#gradio-animation {
    font-size: 2em;
    font-weight: bold;
    text-align: center;
    margin-bottom: 20px;
}

.logo-container img {
    width: 14%; /* Adjust width as necessary */
    display: block;
    margin: auto;
}

.number-input {
    height: 100%;
    padding-bottom: 60px; /* Adjust the value as needed for more or less space */
}
.full-height {
    height: 100%;
}
.column-container {
    height: 100%;
}
"""
|
|
|
|
|
|
|
|
|
class Seafoam(Base):
    """Gradio theme for this demo: emerald/blue/gray hues, Quicksand body font.

    Every argument is keyword-only and is forwarded unchanged to
    ``gradio.themes.base.Base``. The defaults are what make the theme distinct.
    """

    # The annotations are quoted so they are not evaluated when the function is
    # defined. Runtime `X | Y` unions between classes need Python >= 3.10 and
    # raise TypeError on older interpreters that Gradio still supports.
    def __init__(
        self,
        *,
        primary_hue: "colors.Color | str" = colors.emerald,
        secondary_hue: "colors.Color | str" = colors.blue,
        neutral_hue: "colors.Color | str" = colors.gray,
        spacing_size: "sizes.Size | str" = sizes.spacing_md,
        radius_size: "sizes.Size | str" = sizes.radius_md,
        text_size: "sizes.Size | str" = sizes.text_lg,
        font: "fonts.Font | str | Iterable[fonts.Font | str]" = (
            fonts.GoogleFont("Quicksand"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: "fonts.Font | str | Iterable[fonts.Font | str]" = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
|
|
|
|
|
# Single theme instance, passed to gr.Blocks(theme=...) when the UI is built.
seafoam = Seafoam()
|
|
|
|
|
# Client-side snippet passed to gr.Blocks(js=...). It appends the title
# 'Voice of Jungle' to the element with id "gradio-animation" one letter every
# 250 ms. Each letter starts transparent and fades in through a 0.5 s CSS
# opacity transition.
# NOTE(review): assumes that element exists when the snippet runs; if it does
# not, `container` is null and appendChild throws in the browser.
js = """
function createGradioAnimation() {
    var container = document.getElementById('gradio-animation');
    var text = 'Voice of Jungle';
    for (var i = 0; i < text.length; i++) {
        (function(i){
            setTimeout(function(){
                var letter = document.createElement('span');
                letter.style.opacity = '0';
                letter.style.transition = 'opacity 0.5s';
                letter.innerText = text[i];
                container.appendChild(letter);
                setTimeout(function() {
                    letter.style.opacity = '1';
                }, 50);
            }, i * 250);
        })(i);
    }
}
"""
|
|
|
# Markdown bibliography shown at the bottom of the page. The heading uses the
# same "# ..." style as the section headings in DESCRIPTION.
REFERENCES = """
# References

[1] Huang, P.-Y., Xu, H., Li, J., Baevski, A., Auli, M., Galuba, W., Metze, F., & Feichtenhofer, C. (2022). Masked Autoencoders that Listen. In NeurIPS.

[2] Torkington, S. (2023, February 7). 50% of the global economy is under threat from biodiversity loss. World Economic Forum. Retrieved from https://www.weforum.org/agenda/2023/02/biodiversity-nature-loss-cop15/.

[3] https://www.kaggle.com/code/dima806/bird-species-by-sound-detection
"""
|
with gr.Blocks(theme=seafoam, css=css, js=js) as demo:
    # Header: logo, the title placeholder filled in by the `js` snippet, then
    # the project description.
    logo_html = (
        '<div class="logo-container">'
        '<img src="https://i.ibb.co/vcG9kr0/vojlogo.jpg" width="50px" alt="vojlogo">'
        '</div>'
    )
    for header_markdown in (logo_html, '<div id="gradio-animation"></div>', DESCRIPTION):
        gr.Markdown(header_markdown)

    # Inputs: segment bounds on the left, the recording on the right.
    with gr.Row():
        with gr.Column(elem_classes="column-container"):
            start_time_input = gr.Number(label="Start Time", value=0, elem_classes="number-input full-height")
            end_time_input = gr.Number(label="End Time", value=10, elem_classes="number-input full-height")
        with gr.Column():
            audio_input = gr.Audio(label="Input Audio", elem_classes="full-height")

    # Outputs: two 10-row prediction tables side by side, then the two plots.
    with gr.Row():
        raw_class_output, species_output = (
            gr.Dataframe(headers=["Class", "Score [%]"], row_count=10, label=table_label)
            for table_label in ("Class Prediction", "Species Prediction")
        )

    with gr.Row():
        waveform_output = gr.Plot(label="Waveform")
        spectrogram_output = gr.Plot(label="Spectrogram")

    gr.Examples(
        examples=[["1094_Pionus_fuscus_2.wav", 0, 10]],
        inputs=[audio_input, start_time_input, end_time_input],
    )

    predict_button = gr.Button("Predict")
    predict_button.click(
        fn=predict,
        inputs=[audio_input, start_time_input, end_time_input],
        outputs=[raw_class_output, species_output, waveform_output, spectrogram_output],
    )

    gr.Markdown(REFERENCES)

demo.launch(share=True)
|
|
|
|
|
|