BirdSet Dataset and Models
EfficientNet-B1 trained on the XCL dataset from BirdSet, covering 9736 bird species from Xeno-Canto. Please refer to the BirdSet Paper and the BirdSet Repository for further information.
The BirdSet data needs a custom processor, which is available in the BirdSet repository; this model does not ship with a Hugging Face processor, so the preprocessing is done manually in the example below.
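If you also want the underlying data, the BirdSet datasets are hosted on the Hugging Face Hub. A minimal sketch, assuming the XCL configuration of DBD-research-group/BirdSet and a datasets version that supports trust_remote_code:

from datasets import load_dataset

# Assumption: the XCL config exists under DBD-research-group/BirdSet on the Hub.
# Note that XCL is the full Xeno-Canto training set and is very large.
dataset = load_dataset("DBD-research-group/BirdSet", "XCL", trust_remote_code=True)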
The model accepts a mono spectrogram as input, e.g. a tensor of shape torch.Size([16, 1, 256, 417]) (batch, channels, mel bins, time frames).
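A quick way to confirm the expected input and output shapes is a dummy forward pass. This is a minimal sketch; it loads the same checkpoint as the full example below:

import torch
from transformers import EfficientNetForImageClassification

model = EfficientNetForImageClassification.from_pretrained(
    "DBD-research-group/EfficientNet-B1-BirdSet-XCL",
    num_channels=1,
    ignore_mismatched_sizes=True,
)
dummy = torch.randn(16, 1, 256, 417)  # (batch, channels, mel bins, time frames)
with torch.no_grad():
    logits = model(dummy).logits
print(logits.shape)  # expected: torch.Size([16, 9736]), one logit per XCL species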
See the model implementation in the BirdSet repository. The full example below can be run in Google Colab:
from transformers import EfficientNetForImageClassification
import torch
import torchaudio
from torchvision import transforms
import requests
import io
# Download the audio file of a bird sound: Common Crane (Xeno-Canto recording XC704485)
url = "https://xeno-canto.org/704485/download"
response = requests.get(url)
audio, sample_rate = torchaudio.load(io.BytesIO(response.content))
print("Original shape and sample rate: ", audio.shape, sample_rate)
# Downmix to mono in case the recording is stereo; the model expects one channel
audio = audio.mean(dim=0, keepdim=True)
# Crop to the first 5 seconds
audio = audio[:, : 5 * sample_rate]
# resample to 32kHz
resample = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=32000)
audio = resample(audio)
print("Resampled shape and sample rate: ", audio.shape, 32000)
CACHE_DIR = "../../data_birdset" # Change this to your own cache directory
# Load the model
model = EfficientNetForImageClassification.from_pretrained(
    "DBD-research-group/EfficientNet-B1-BirdSet-XCL",
    num_channels=1,
    cache_dir=CACHE_DIR,
    ignore_mismatched_sizes=True,
)
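# num_channels=1 makes the stem convolution accept single-channel spectrograms
# instead of 3-channel RGB images; ignore_mismatched_sizes lets from_pretrained
# re-initialize any weights whose shapes differ from the checkpoint instead of erroring.
model.eval()  # Inference mode: disables dropout and freezes batch-norm statistics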
class PowerToDB(torch.nn.Module):
    """
    A power spectrogram to decibel conversion layer.
    See birdset.datamodule.components.augmentations in the BirdSet repository.
    """

    def __init__(self, ref=1.0, amin=1e-10, top_db=80.0):
        super().__init__()
        self.ref = ref
        self.amin = amin
        self.top_db = top_db

    def forward(self, S):
        # Convert S to a tensor if it is not already
        S = torch.as_tensor(S)
        if self.amin <= 0:
            raise ValueError("amin must be strictly positive")
        # Complex spectrograms are reduced to their magnitude first
        if torch.is_complex(S):
            magnitude = S.abs()
        else:
            magnitude = S
        magnitude = magnitude.to(torch.float32)
        # ref may be a callable computed from the input, or a scalar
        if callable(self.ref):
            ref_value = self.ref(magnitude)
        else:
            ref_value = torch.abs(torch.tensor(self.ref, dtype=magnitude.dtype))
        # Compute the log spectrogram in dB relative to ref
        log_spec = 10.0 * torch.log10(
            torch.maximum(magnitude, torch.tensor(self.amin, device=magnitude.device))
        )
        log_spec -= 10.0 * torch.log10(
            torch.maximum(ref_value, torch.tensor(self.amin, device=magnitude.device))
        )
        # Clamp the dynamic range to top_db below the peak, if requested
        if self.top_db is not None:
            if self.top_db < 0:
                raise ValueError("top_db must be non-negative")
            log_spec = torch.maximum(log_spec, log_spec.max() - self.top_db)
        return log_spec
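# Optional sanity check: with the default ref=1.0, a power of 1.0 should map to
# 0 dB and 0.1 to -10 dB (this layer appears to mirror librosa.power_to_db).
assert torch.allclose(PowerToDB()(torch.tensor([1.0, 0.1])), torch.tensor([0.0, -10.0]))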
def preprocess(audio, sample_rate_of_audio):
    """
    Preprocess the audio into the format the model expects:
    - Resample to 32 kHz
    - Convert to a power spectrogram (n_fft=2048, hop_length=256, power=2.0)
      and map it to the mel scale (n_mels=256, n_stft=1025)
    - Convert to dB and normalize with mean=-4.268, std=4.569 (AudioSet statistics)
    """
    power_to_db = PowerToDB()
    # Resample to 32 kHz
    resample = torchaudio.transforms.Resample(
        orig_freq=sample_rate_of_audio, new_freq=32000
    )
    audio = resample(audio)
    spectrogram = torchaudio.transforms.Spectrogram(
        n_fft=2048, hop_length=256, power=2.0
    )(audio)
    melspec = torchaudio.transforms.MelScale(n_mels=256, n_stft=1025)(spectrogram)
    dbscale = power_to_db(melspec)
    normalized_dbscale = transforms.Normalize((-4.268,), (4.569,))(dbscale)
    return normalized_dbscale
preprocessed_audio = preprocess(audio, 32000)  # the audio was already resampled to 32 kHz above
print("Preprocessed_audio shape:", preprocessed_audio.shape)
with torch.no_grad():  # no gradients needed for inference
    logits = model(preprocessed_audio.unsqueeze(0)).logits
print("Logits shape: ", logits.shape)
top5 = torch.topk(logits, 5)
print("Top 5 logits:", top5.values)
print("Top 5 predicted classes:")
print([model.config.id2label[i] for i in top5.indices.squeeze().tolist()])
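Bird vocalization recognition is framed as a multi-label task in the BirdSet benchmark, so per-class sigmoid scores are generally more meaningful than raw logits or a softmax over species. A minimal sketch of turning the logits into confidence scores:

probs = torch.sigmoid(logits)
top5 = torch.topk(probs, 5)
print("Top 5 probabilities:", top5.values)
print([model.config.id2label[i] for i in top5.indices.squeeze().tolist()])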