import json
import os
import random

import gradio as gr
import librosa
import librosa.display  # explicit import so librosa.display.specshow works on older librosa versions
import matplotlib.pyplot as plt
import noisereduce as nr
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn.functional as F
from torchaudio.compliance import kaldi
from torchaudio.functional import resample
from transformers import ASTFeatureExtractor

from config import Config
from model import BirdAST

FEATURE_EXTRACTOR = ASTFeatureExtractor()
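
# Render a mel spectrogram of the clip (224 mel bands, 0-10 kHz).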
def plot_mel(sr, x):
    mel_spec = librosa.feature.melspectrogram(y=x, sr=sr, n_mels=224, fmax=10000)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
    mel_spec_db = (mel_spec_db - mel_spec_db.min()) / (mel_spec_db.max() - mel_spec_db.min())
    mel_spec_db = np.stack([mel_spec_db, mel_spec_db, mel_spec_db], axis=-1)

    fig, ax = plt.subplots(nrows=1, ncols=1, sharex=True)
    librosa.display.specshow(mel_spec_db[:, :, 0], sr=sr, x_axis='time', y_axis='mel', fmin=0, fmax=10000, ax=ax)
    return fig
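
# Plot the original waveform alongside a noise-reduced version produced with noisereduce.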
def plot_wave(sr, x):
    ry = nr.reduce_noise(y=x, sr=sr)
    fig, ax = plt.subplots(2, 1, figsize=(12, 8))

    librosa.display.waveshow(x, sr=sr, ax=ax[0])
    ax[0].set(title='Original Waveform')
    ax[0].set_xlabel('Time (s)')
    ax[0].set_ylabel('Amplitude')

    librosa.display.waveshow(ry, sr=sr, ax=ax[1])
    ax[1].set(title='Noise Reduced Waveform')
    ax[1].set_xlabel('Time (s)')
    ax[1].set_ylabel('Amplitude')

    plt.tight_layout()
    return fig
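
# Gradio callback: validate the selected time window, run inference, and build both plots.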
def predict(audio, start, end):
    sr, x = audio
    x = np.array(x, dtype=np.float64)

    if start >= end:
        raise gr.Error(f"`start` ({start} s) must be smaller than `end` ({end} s)")

    if x.shape[0] < start * sr:
        raise gr.Error(f"`start` ({start} s) must be smaller than the audio duration ({x.shape[0] / sr:.0f} s)")

    if x.shape[0] > end * sr:
        end = x.shape[0] / sr

    res = preprocess_for_inference(x, sr)
    fig_spec = plot_mel(sr, x)
    fig_wave = plot_wave(sr, x)

    # The same predictions feed both result tables; plots are returned in the
    # order expected by the UI outputs (waveform first, then spectrogram).
    return res, res, fig_wave, fig_spec
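
# Fetch the model checkpoint once and cache it next to the app.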
def download_model(url, model_path):
    if not os.path.exists(model_path):
        response = requests.get(url)
        response.raise_for_status()
        with open(model_path, 'wb') as f:
            f.write(response.content)


model_url = 'https://huggingface.co/shiyi-li/BirdAST/resolve/main/BirdAST_Baseline_fold_1.pth'
model_path = 'BirdAST_Baseline_fold_1.pth'
download_model(model_url, model_path)
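
# Build the BirdAST model, load the downloaded weights on CPU, and switch to eval mode.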
eval_model = BirdAST(Config().backbone_name, Config().n_classes, n_mlp_layers=1, activation='silu')
state_dict = torch.load(model_path, map_location='cpu')
eval_model.load_state_dict(state_dict)
eval_model.eval()

# Map integer class ids to scientific species names
label_mapping = pd.read_csv('label_mapping.csv')
species_id_to_name = {row['species_id']: row['scientific_name'] for _, row in label_mapping.iterrows()}
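
# Extract AST features from the raw audio and return the top-10 (species, probability) pairs.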
def preprocess_for_inference(audio_arr, sr):
    # The AST feature extractor expects audio at its own sampling rate (16 kHz by default),
    # so resample the clip first if it was recorded at a different rate.
    waveform = torch.from_numpy(audio_arr).float()
    target_sr = FEATURE_EXTRACTOR.sampling_rate
    if sr != target_sr:
        waveform = resample(waveform, sr, target_sr)
        sr = target_sr

    spec = FEATURE_EXTRACTOR(waveform.numpy(), sampling_rate=sr, padding="max_length", return_tensors="pt")
    input_values = spec['input_values']

    results = []
    with torch.no_grad():
        output = eval_model(input_values)
        predict_score = F.softmax(output['logits'], dim=1)
        topk_values, topk_indices = torch.topk(predict_score, 10, dim=1)

        for idx, score in zip(topk_indices[0], topk_values[0]):
            species_name = species_id_to_name[idx.item()]
            results.append([species_name, score.item()])

    return results

DESCRIPTION = """
Bird audio classification using state-of-the-art Voice of Jungle technology.
"""
# Earlier gr.Interface layout, left here disabled (a bare string, never executed):
"""
with gr.Blocks() as demo:
    submit_btn = gr.Button("Submit")

demo = gr.Interface(
    title="Bird audio classification",
    description=DESCRIPTION,
    fn=predict,
    inputs=["audio", "number", "number"],
    outputs=[
        gr.Dataframe(headers=["class", "score"], row_count=10, label="prediction"),
        gr.Plot(label="waveform"),
        gr.Plot(label="spectrogram"),
    ],
    examples=[
        ["312_Cissopis_leverinia_1.wav", 0, 5],
        ["1094_Pionus_fuscus_2.wav", 0, 10],
    ],
)
"""

css = """
.number-input {
    height: 100%;
    padding-bottom: 60px; /* Adjust the value as needed for more or less space */
}
.full-height {
    height: 100%;
}
.column-container {
    height: 100%;
}
"""
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Bird Species Audio Classification")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(elem_classes="column-container"):
            start_time_input = gr.Number(label="Start Time", value=0, elem_classes="number-input full-height")
            end_time_input = gr.Number(label="End Time", value=1, elem_classes="number-input full-height")
        with gr.Column():
            audio_input = gr.Audio(label="Input Audio", elem_classes="full-height")

    with gr.Row():
        raw_class_output = gr.Dataframe(headers=["class", "score"], row_count=10, label="Class Prediction")
        species_output = gr.Dataframe(headers=["class", "score"], row_count=10, label="Species Prediction")

    with gr.Row():
        waveform_output = gr.Plot(label="Waveform")
        spectrogram_output = gr.Plot(label="Spectrogram")

    gr.Examples(
        examples=[
            ["312_Cissopis_leverinia_1.wav", 0, 5],
            ["1094_Pionus_fuscus_2.wav", 0, 10],
        ],
        inputs=[audio_input, start_time_input, end_time_input],
    )

    gr.Button("Predict").click(
        predict,
        inputs=[audio_input, start_time_input, end_time_input],
        outputs=[raw_class_output, species_output, waveform_output, spectrogram_output],
    )

demo.launch(share=True)