---
library_name: transformers
tags: []
---

# ConvNeXT (trained on XCL from BirdSet)

ConvNeXT trained on the XCL dataset from BirdSet, covering 9736 bird species from Xeno-Canto. Please refer to the [BirdSet Paper](https://arxiv.org/pdf/2403.10380) and the [BirdSet Repository](https://github.com/DBD-research-group/BirdSet/tree/main) for further information.

## Model Details

ConvNeXT is a pure convolutional model (ConvNet), inspired by the design of Vision Transformers, that claims to outperform them.

## How to use

The BirdSet data requires a custom processor that is available in the BirdSet repository; no Hugging Face processor ships with this model, so preprocessing must be done manually as shown below. The model accepts a single-channel (mono) spectrogram image as input, e.g., `torch.Size([16, 1, 128, 334])`.

- The model is trained on 5-second clips of bird vocalizations.
- num_channels: 1
- pretrained checkpoint: facebook/convnext-base-224-22k
- sampling_rate: 32_000
- normalize spectrogram: mean: -4.268, std: 4.569 (AudioSet statistics)
- spectrogram: n_fft: 1024, hop_length: 320, power: 2.0
- melscale: n_mels: 128, n_stft: 513
- dbscale: top_db: 80
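To see how these parameters determine the model's input shape, here is a minimal sketch (an illustration with random audio in place of a real recording, not the official BirdSet processor). With torchaudio's default center padding, a 5-second clip at 32 kHz yields 160_000 / 320 + 1 = 501 spectrogram frames:

```python
import torch
import torchaudio

# Stand-in for a real recording: mono, 5 seconds at 32 kHz
waveform = torch.randn(1, 5 * 32_000)

# Same parameters as listed above
spec = torchaudio.transforms.Spectrogram(n_fft=1024, hop_length=320, power=2.0)(waveform)
mel = torchaudio.transforms.MelScale(n_mels=128, n_stft=513)(spec)

# torch.Size([1, 128, 501]) -> unsqueeze to (1, 1, 128, 501) before passing to the model
print(mel.shape)
```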
See the [example inference notebook](https://github.com/DBD-research-group/BirdSet/blob/main/notebooks/tutorials/model_inference.ipynb), or run the full example below in [Google Colab](https://colab.research.google.com/drive/1pp_RCJEjSR4gPBGFtxDdgnr4Uk1_KimU?usp=sharing):

```python
import io

import requests
import torch
import torchaudio
from torchvision import transforms
from transformers import ConvNextForImageClassification

# Download an example bird recording from Xeno-Canto (XC704485)
url = "https://xeno-canto.org/704485/download"
response = requests.get(url)
audio, sample_rate = torchaudio.load(io.BytesIO(response.content))
print("Original shape and sample rate:", audio.shape, sample_rate)

# Crop to the first 5 seconds (the model is trained on 5-second clips)
audio = audio[:, : 5 * sample_rate]

# Resample to 32 kHz
resample = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=32_000)
audio = resample(audio)
print("Resampled shape and sample rate:", audio.shape, 32_000)

CACHE_DIR = "../../data_birdset"  # Change this to your own cache directory

# Load the model
model = ConvNextForImageClassification.from_pretrained(
    "DBD-research-group/ConvNeXT-Base-BirdSet-XCL",
    cache_dir=CACHE_DIR,
    ignore_mismatched_sizes=True,
)


class PowerToDB(torch.nn.Module):
    """A power spectrogram to decibel conversion layer.

    See birdset.datamodule.components.augmentations.
    """

    def __init__(self, ref=1.0, amin=1e-10, top_db=80.0):
        super().__init__()
        self.ref = ref
        self.amin = amin
        self.top_db = top_db

    def forward(self, S):
        # Convert S to a tensor if it is not already one
        S = torch.as_tensor(S, dtype=torch.float32)

        if self.amin <= 0:
            raise ValueError("amin must be strictly positive")

        magnitude = S.abs() if torch.is_complex(S) else S

        # ref may be a callable (computed from the input) or a scalar
        if callable(self.ref):
            ref_value = self.ref(magnitude)
        else:
            ref_value = torch.abs(torch.tensor(self.ref, dtype=S.dtype))

        # Compute the log spectrogram relative to ref
        log_spec = 10.0 * torch.log10(
            torch.maximum(magnitude, torch.tensor(self.amin, device=magnitude.device))
        )
        log_spec -= 10.0 * torch.log10(
            torch.maximum(ref_value, torch.tensor(self.amin, device=magnitude.device))
        )

        # Clamp the dynamic range to top_db if requested
        if self.top_db is not None:
            if self.top_db < 0:
                raise ValueError("top_db must be non-negative")
            log_spec = torch.maximum(log_spec, log_spec.max() - self.top_db)

        return log_spec


def preprocess(audio, sample_rate_of_audio):
    """Preprocess the audio to the format the model expects.

    - Resample to 32 kHz
    - Convert to a mel spectrogram: n_fft: 1024, hop_length: 320, power: 2.0,
      n_mels: 128, n_stft: 513
    - Convert to decibel scale and normalize with mean -4.268 and std 4.569
      (AudioSet statistics)
    """
    power_to_db = PowerToDB()
    # Resample to 32 kHz
    resample = torchaudio.transforms.Resample(
        orig_freq=sample_rate_of_audio, new_freq=32_000
    )
    audio = resample(audio)
    spectrogram = torchaudio.transforms.Spectrogram(
        n_fft=1024, hop_length=320, power=2.0
    )(audio)
    melspec = torchaudio.transforms.MelScale(n_mels=128, n_stft=513)(spectrogram)
    dbscale = power_to_db(melspec)
    normalized_dbscale = transforms.Normalize((-4.268,), (4.569,))(dbscale)
    return normalized_dbscale


# The audio was already resampled to 32 kHz above, so pass 32_000 here;
# passing the original sample rate would incorrectly resample a second time.
preprocessed_audio = preprocess(audio, 32_000)

logits = model(preprocessed_audio.unsqueeze(0)).logits
print("Logits shape:", logits.shape)

top5 = torch.topk(logits, 5)
print("Top 5 logits:", top5.values)
print("Top 5 predicted classes:")
print([model.config.id2label[i] for i in top5.indices.squeeze().tolist()])
```

## Model Source

- **Repository:** [BirdSet Repository](https://github.com/DBD-research-group/BirdSet/tree/main)
- **Paper:** [BirdSet Paper](https://arxiv.org/pdf/2403.10380)

## Citation

Please cite the [BirdSet Paper](https://arxiv.org/pdf/2403.10380) when using this model.