EfficientNet (trained on XCL from BirdSet)

EfficientNet trained on the XCL dataset from BirdSet, covering 9,736 bird species from Xeno-Canto. Please refer to the BirdSet paper and the BirdSet repository for further information.

How to use

The BirdSet data requires a custom processor, which is available in the BirdSet repository; the model itself does not ship with a processor. The model expects a mono (single-channel) spectrogram image as input, shaped (batch, channel, mel bins, time frames), e.g. torch.Size([16, 1, 256, 417]).
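The last two dimensions of that input follow from the spectrogram parameters listed below. The sketch computes them for a single clip. It assumes torchaudio's default center=True framing and the hop length of 256 used in the preprocess function further down. The helper name `expected_spectrogram_shape` is hypothetical.

```python
def expected_spectrogram_shape(
    num_samples: int, n_fft: int, hop_length: int, n_mels: int
) -> tuple[int, int, int]:
    """Return (n_stft, n_mels, num_frames) for a centered STFT followed by a mel filterbank."""
    n_stft = n_fft // 2 + 1  # one-sided frequency bins of the STFT (2048 -> 1025)
    # torchaudio's Spectrogram uses center=True by default: the signal is padded by
    # n_fft // 2 on both sides, giving 1 + num_samples // hop_length frames
    num_frames = 1 + num_samples // hop_length
    return n_stft, n_mels, num_frames


clip_samples = 5 * 32_000  # a 5-second clip at 32 kHz
n_stft, n_mels, frames = expected_spectrogram_shape(
    clip_samples, n_fft=2048, hop_length=256, n_mels=256
)
print(n_stft, n_mels, frames)  # 1025 256 626
print("model input shape for one clip:", (1, 1, n_mels, frames))
```

This is also why n_stft is 1025 for n_fft 2048. With center=False the frame count would instead be 1 + (num_samples - n_fft) // hop_length.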

  • The model is trained on 5-second clips of bird vocalizations.
  • num_channels: 1
  • pretrained checkpoint: google/efficientnet-b1
  • sampling_rate: 32_000
  • normalize spectrogram: mean: -4.268, std: 4.569 (from esc-50)
  • spectrogram: n_fft: 2048, hop_length: 256, power: 2.0
  • melscale: n_mels: 256, n_stft: 1025
  • dbscale: top_db: 80

See model implementation. Run in Google Colab:

import io

import requests
import torch
import torchaudio
from torchvision import transforms
from transformers import EfficientNetForImageClassification


# download the audio file of a bird sound: Common Craw
url = "https://xeno-canto.org/704485/download"
response = requests.get(url)
audio, sample_rate = torchaudio.load(io.BytesIO(response.content))
print("Original shape and sample rate: ", audio.shape, sample_rate)
# downmix to mono (the model expects a single channel), then crop to 5 seconds
audio = audio.mean(dim=0, keepdim=True)
audio = audio[:, : 5 * sample_rate]
# resample to 32kHz
resample = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=32000)
audio = resample(audio)
print("Resampled shape and sample rate: ", audio.shape, 32000)


CACHE_DIR = "../../data_birdset"  # Change this to your own cache directory

# Load the model
model = EfficientNetForImageClassification.from_pretrained(
    "DBD-research-group/EfficientNet-B1-BirdSet-XCL",
    num_channels=1,
    cache_dir=CACHE_DIR,
    ignore_mismatched_sizes=True,
)


class PowerToDB(torch.nn.Module):
    """
    A power spectrogram to decibel conversion layer. See birdset.datamodule.components.augmentations
    """

    def __init__(self, ref=1.0, amin=1e-10, top_db=80.0):
        super(PowerToDB, self).__init__()
        # Initialize parameters
        self.ref = ref
        self.amin = amin
        self.top_db = top_db

    def forward(self, S):
        # Convert S to a PyTorch tensor if it is not already
        S = torch.as_tensor(S, dtype=torch.float32)

        if self.amin <= 0:
            raise ValueError("amin must be strictly positive")

        if torch.is_complex(S):
            magnitude = S.abs()
        else:
            magnitude = S

        # Check if ref is a callable function or a scalar
        if callable(self.ref):
            ref_value = self.ref(magnitude)
        else:
            ref_value = torch.abs(torch.tensor(self.ref, dtype=S.dtype))

        # Compute the log spectrogram
        log_spec = 10.0 * torch.log10(
            torch.maximum(magnitude, torch.tensor(self.amin, device=magnitude.device))
        )
        log_spec -= 10.0 * torch.log10(
            torch.maximum(ref_value, torch.tensor(self.amin, device=magnitude.device))
        )

        # Apply top_db threshold if necessary
        if self.top_db is not None:
            if self.top_db < 0:
                raise ValueError("top_db must be non-negative")
            log_spec = torch.maximum(log_spec, log_spec.max() - self.top_db)

        return log_spec


def preprocess(audio, sample_rate_of_audio):
    """
    Preprocess the audio to the format that the model expects
    - Resample to 32kHz
    - Convert to melscale spectrogram n_fft: 2048, hop_length: 256, power: 2. melscale: n_mels: 256, n_stft: 1025
    - Normalize the melscale spectrogram with mean: -4.268, std: 4.569 (from AudioSet)

    """
    powerToDB = PowerToDB()
    # Resample to 32kHz
    resample = torchaudio.transforms.Resample(
        orig_freq=sample_rate_of_audio, new_freq=32000
    )
    audio = resample(audio)
    spectrogram = torchaudio.transforms.Spectrogram(
        n_fft=2048, hop_length=256, power=2.0
    )(audio)
    melspec = torchaudio.transforms.MelScale(n_mels=256, n_stft=1025)(spectrogram)
    dbscale = powerToDB(melspec)
    normalized_dbscale = transforms.Normalize((-4.268,), (4.569,))(dbscale)
    return normalized_dbscale

# audio was already resampled to 32 kHz above, so pass 32000 to avoid resampling it twice
preprocessed_audio = preprocess(audio, 32000)
print("Preprocessed_audio shape:", preprocessed_audio.shape)


with torch.no_grad():  # inference only; no gradients needed
    logits = model(preprocessed_audio.unsqueeze(0)).logits
print("Logits shape: ", logits.shape)

top5 = torch.topk(logits, 5)
print("Top 5 logits:", top5.values)
print("Top 5 predicted classes:")
print([model.config.id2label[i] for i in top5.indices.squeeze().tolist()])
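BirdSet frames bird sound classification as a multi-label task (several species can vocalize in one clip). Under the assumption that the XCL head was trained that way, converting logits into independent per-class probabilities with a sigmoid may be more informative than ranking raw logits. The threshold of 0.5 and the helper name `species_above_threshold` are illustrative choices, not values from BirdSet.

```python
import torch


def species_above_threshold(
    logits: torch.Tensor, id2label: dict[int, str], threshold: float = 0.5
) -> list[tuple[str, float]]:
    """Return (label, probability) pairs whose sigmoid probability exceeds `threshold`,
    sorted from most to least likely. `logits` has shape (1, num_classes)."""
    probs = torch.sigmoid(logits.squeeze(0))  # independent per-class probabilities
    hits = (probs > threshold).nonzero(as_tuple=True)[0]
    ranked = sorted(hits.tolist(), key=lambda i: probs[i].item(), reverse=True)
    return [(id2label[i], round(probs[i].item(), 3)) for i in ranked]


# usage with dummy logits; with the real model, pass `logits` and `model.config.id2label`
dummy_logits = torch.tensor([[2.0, -3.0, 0.5, -0.1]])
dummy_labels = {0: "species_a", 1: "species_b", 2: "species_c", 3: "species_d"}
print(species_above_threshold(dummy_logits, dummy_labels))
```

A per-class threshold can report zero, one, or several species per clip, which top-k ranking cannot express.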

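The script keeps only the first 5 seconds of the recording (audio[:, : 5 * sample_rate]). Because the model was trained on 5-second clips, a longer recording could instead be split into consecutive 5-second windows, zero-padding the last one. This is a minimal sketch under that assumption. It expects a mono waveform of shape (1, num_samples), and `split_into_windows` is a hypothetical helper, not part of BirdSet or transformers.

```python
import torch
import torch.nn.functional as F


def split_into_windows(
    audio: torch.Tensor, sample_rate: int, window_seconds: float = 5.0
) -> torch.Tensor:
    """Split a mono waveform of shape (1, num_samples) into (num_windows, 1, window_samples).

    The final window is zero-padded on the right so every window has the same length.
    """
    window = int(window_seconds * sample_rate)
    num_samples = audio.shape[-1]
    num_windows = max(1, -(-num_samples // window))  # ceiling division, at least one window
    padded = F.pad(audio, (0, num_windows * window - num_samples))
    # (1, num_windows * window) -> (num_windows, 1, window)
    return padded.reshape(num_windows, window).unsqueeze(1)


# usage: 12 seconds of audio at 32 kHz -> three 5-second windows
waveform = torch.randn(1, 12 * 32_000)
windows = split_into_windows(waveform, sample_rate=32_000)
print(windows.shape)  # torch.Size([3, 1, 160000])
```

Each window (shape (1, 160000)) could then be passed through preprocess. How to combine the per-window predictions, for example by taking the maximum per class, is a design choice this card does not specify.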
Model Source

Citation

Model size: 19M params (Safetensors, tensor type F32)