Commit 557fb53 · Parent: e82ec2b
Refactor config style and reorganize files
- .gitignore +1 -0
- TODO.md +5 -2
- environment.yml +5 -0
- models/audio_spectrogram_transformer.py +117 -76
- models/config/decision_tree.yaml +47 -0
- models/config/train.yaml +5 -5
- models/config/train_local.yaml +47 -36
- models/decision_tree.py +124 -37
- models/residual.py +82 -86
- models/training_environment.py +90 -0
- models/utils.py +47 -20
- models/wav2vec2.py +84 -0
- preprocessing/dataset.py +230 -198
- preprocessing/pipelines.py +56 -42
- preprocessing/preprocess.py +66 -44
- tests.py +0 -22
- tests/test_datasets.py +17 -0
- tests/test_pipelines.py +13 -0
- tests/utils.py +7 -0
- train.py +9 -176
.gitignore
CHANGED
@@ -9,3 +9,4 @@ lightning_logs
 .lr_find_*
 .cache
 .vscode
+models/weights/ast
TODO.md
CHANGED
@@ -6,10 +6,13 @@
 - Create an attention-based network
 - ✅ Increase parameter count in network
 - Verify that labels really match what is on the music4dance site
-- Read the Medium series about audio DL
+- ✅ Read the Medium series about audio DL
 - double check \_rectify_duration
 - ✅ Filter out songs that have only one vote
+- ✅ Download songs from [Best Ballroom](https://www.youtube.com/channel/UC0bYSnzAFMwPiEjmVsrvmRg)
+
+- ✅ fix nan values
 
 ## Notes
 
-2xM60 insufficient memory.
+2xM60 insufficient memory for the AST.
environment.yml
CHANGED
@@ -23,6 +23,11 @@ dependencies:
   - scikit-learn
   - tensorboard
   - transformers
+  - accelerate
+  - pytest
+
   - pip:
       - evaluate
      - wakepy
+      - soundfile
+      - youtube_dl
models/audio_spectrogram_transformer.py
CHANGED
@@ -1,93 +1,138 @@
+from typing import Any
+import pandas as pd
+from sklearn.model_selection import train_test_split
+from transformers import (
+    AutoFeatureExtractor,
+    AutoModelForAudioClassification,
+    TrainingArguments,
+    Trainer,
+    ASTConfig,
+    ASTFeatureExtractor,
+    ASTForAudioClassification,
+)
 import torch
 from torch import nn
-import numpy as np
+from models.training_environment import TrainingEnvironment
+from preprocessing.pipelines import WaveformTrainingPipeline
 
+from preprocessing.dataset import (
+    DanceDataModule,
+    HuggingFaceDatasetWrapper,
+    get_datasets,
+)
+from preprocessing.dataset import get_music4dance_examples
+from .utils import get_id_label_mapping, compute_hf_metrics
 
+import pytorch_lightning as pl
+from pytorch_lightning import callbacks as cb
 
+MODEL_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"
 
 
+class AST(nn.Module):
+    def __init__(self, labels, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         id2label, label2id = get_id_label_mapping(labels)
-            ignore_mismatched_sizes=True
-        )
-        self.sample_rate = sample_rate
-
-        self.bpm_model = nn.Sequential(
-            nn.Linear(len(labels), 100),
-            nn.Linear(100, 50)
-        )
-
-        out_dim = 50 # TODO: Calculate output dimension
-        self.classifier = nn.Sequential(
-            nn.Linear(out_dim, 100),
-            nn.Linear(100, len(labels))
+        config = ASTConfig(
+            hidden_size=300,
+            num_attention_heads=5,
+            num_hidden_layers=3,
+            id2label=id2label,
+            label2id=label2id,
+            num_labels=len(label2id),
+            ignore_mismatched_sizes=True,
         )
+        self.model = ASTForAudioClassification(config)
 
-    def vectorize_bpm(self, waveform):
-        pass
-
-    def forward(self, audio):
-        bpm_vector = self.vectorize_bpm(audio)
-        bpm_out = self.bpm_model(bpm_vector)
-
-        spectrogram = self.ast_feature_extractor(audio)
-        ast_out = self.ast_model(spectrogram)
-
-        # Late fusion
-        z = torch.cat([ast_out, bpm_out]) # Which dimension?
-        return self.classifier(z)
-
-
-def compute_metrics(eval_pred):
-    predictions = np.argmax(eval_pred.predictions, axis=1)
-    return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)
+    def forward(self, x):
+        return self.model(x).logits
 
 
+class ASTExtractorWrapper:
+    def __init__(self, sampling_rate=16000, return_tensors="pt") -> None:
+        self.extractor = ASTFeatureExtractor()
+        self.sampling_rate = sampling_rate
+        self.return_tensors = return_tensors
+        self.waveform_pipeline = WaveformTrainingPipeline()  # TODO configure from yaml
+
+    def __call__(self, x) -> Any:
+        x = self.waveform_pipeline(x)
+        device = x.device
+        x = x.squeeze(0).numpy()
+        x = self.extractor(
+            x, return_tensors=self.return_tensors, sampling_rate=self.sampling_rate
+        )
+        return x["input_values"].squeeze(0).to(device)
+
+
+def train_lightning_ast(config: dict):
+    """
+    work on integration between waveform dataset and environment. Should work for both HF and PTL.
+    """
+    TARGET_CLASSES = config["dance_ids"]
+    DEVICE = config["device"]
+    SEED = config["seed"]
+    pl.seed_everything(SEED, workers=True)
+    feature_extractor = ASTExtractorWrapper()
+    dataset = get_datasets(config["datasets"], feature_extractor)
+    data = DanceDataModule(
+        dataset,
+        target_classes=TARGET_CLASSES,
+        **config["data_module"],
+    )
+
+    model = AST(TARGET_CLASSES).to(DEVICE)
+    label_weights = data.get_label_weights().to(DEVICE)
+    criterion = nn.CrossEntropyLoss(
+        label_weights
+    )  # LabelWeightedBCELoss(label_weights)
+    train_env = TrainingEnvironment(model, criterion, config)
+    callbacks = [
+        # cb.LearningRateFinder(update_attr=True),
+        cb.EarlyStopping("val/loss", patience=5),
+        cb.RichProgressBar(),
+    ]
+    trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
+    trainer.fit(train_env, datamodule=data)
+    trainer.test(train_env, datamodule=data)
+
+
+def train_huggingface_ast(config: dict):
+    TARGET_CLASSES = config["dance_ids"]
+    DEVICE = config["device"]
+    SEED = config["seed"]
+    OUTPUT_DIR = "models/weights/ast"
+    batch_size = config["data_module"]["batch_size"]
+    epochs = config["data_module"]["min_epochs"]
+    test_proportion = config["data_module"].get("test_proportion", 0.2)
+    pl.seed_everything(SEED, workers=True)
+    dataset = get_datasets(config["datasets"])
+    hf_dataset = HuggingFaceDatasetWrapper(dataset)
+    id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
     model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
     feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
-    preprocess_waveform = lambda wf
+    preprocess_waveform = lambda wf: feature_extractor(
+        wf,
+        sampling_rate=train_ds.resample_frequency,
+        # padding="max_length",
+        # return_tensors="pt",
+    )
+    hf_dataset.append_to_pipeline(preprocess_waveform)
+    test_proportion = config["data_module"]["test_proportion"]
+    train_proporition = 1 - test_proportion
+    train_ds, test_ds = torch.utils.data.random_split(
+        hf_dataset, [train_proporition, test_proportion]
+    )
 
     model = AutoModelForAudioClassification.from_pretrained(
+        model_checkpoint,
+        num_labels=len(TARGET_CLASSES),
+        label2id=label2id,
+        id2label=id2label,
+        ignore_mismatched_sizes=True,
-    ).to(
+    ).to(DEVICE)
     training_args = TrainingArguments(
-        output_dir=
+        output_dir=OUTPUT_DIR,
         evaluation_strategy="epoch",
         save_strategy="epoch",
         learning_rate=5e-5,
@@ -100,7 +145,7 @@ def train(
         load_best_model_at_end=True,
         metric_for_best_model="accuracy",
         push_to_hub=False,
-        use_mps_device=
+        use_mps_device=DEVICE == "mps",
     )
 
     trainer = Trainer(
@@ -109,11 +154,7 @@ def train(
         train_dataset=train_ds,
        eval_dataset=test_ds,
         tokenizer=feature_extractor,
-        compute_metrics=
+        compute_metrics=compute_hf_metrics,
     )
     trainer.train()
     return model
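
A minimal driver sketch (not part of this commit) showing how the new Lightning entry point can be fed one of the YAML configs under models/config/. It only assumes the keys that train_lightning_ast actually reads (dance_ids, device, seed, datasets, data_module, trainer):

import yaml

from models.audio_spectrogram_transformer import train_lightning_ast

if __name__ == "__main__":
    # Load the flat-style local config and hand the whole dict to the entry point.
    with open("models/config/train_local.yaml") as f:
        config = yaml.safe_load(f)
    train_lightning_ast(config)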
models/config/decision_tree.yaml
ADDED
@@ -0,0 +1,47 @@
global:
  id: decision_tree
  device: mps
  seed: 42
  dance_ids:
    - ATN
    - BCH
    - CHA
    - ECS
    - HST
    - JIV
    - QST
    - RMB
    - SFT
    - SLS
    - SMB
    - SWZ
    - TGO
    - VWZ
    - WCS
data_module:
  song_data_path: data/songs_cleaned.csv
  song_audio_path: data/samples
  batch_size: 32
  num_workers: 7
  min_votes: 1
  dataset_kwargs:
    audio_window_duration: 6
    audio_window_jitter: 1.5
    audio_pipeline_kwargs:
      mask_count: 0 # Don't mask the data
      snr_mean: 15.0 # Pretty much eliminate the noise
      freq_mask_size: 10
      time_mask_size: 80

trainer:
  log_every_n_steps: 15
  accelerator: gpu
  max_epochs: 50
  min_epochs: 5
  fast_dev_run: False
  # gradient_clip_val: 0.5
  # overfit_batches: 1
training_environment:
  learning_rate: 0.00053
model:
  n_channels: 128
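
A small illustration (not from the repo) of the two config styles this commit leaves in place: decision_tree.yaml nests run-level settings under a global: block, so the decision-tree trainer reads config["global"]["dance_ids"], while the flat train_local.yaml style exposes dance_ids at the top level.

import yaml

# Nested style introduced by this file.
with open("models/config/decision_tree.yaml") as f:
    tree_config = yaml.safe_load(f)

print(tree_config["global"]["device"])          # "mps"
print(len(tree_config["global"]["dance_ids"]))  # 15 dance classes
print(tree_config["trainer"]["min_epochs"])     # 5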
models/config/train.yaml
CHANGED
@@ -27,11 +27,11 @@ data_module:
   dataset_kwargs:
     audio_window_duration: 6
     audio_window_jitter: 1.5
-    audio_pipeline_kwargs:
+    # audio_pipeline_kwargs:
+    #   mask_count: 0 # Don't mask the data
+    #   snr_mean: 15.0 # Pretty much eliminate the noise
+    #   freq_mask_size: 10
+    #   time_mask_size: 80
 
 trainer:
   log_every_n_steps: 15
models/config/train_local.yaml
CHANGED
@@ -1,47 +1,58 @@
-  - VWZ
-  - WCS
+training_fn: audio_spectrogram_transformer.train_lightning_ast
+device: mps
+seed: 42
+dance_ids: &dance_ids
+  - BCH
+  - CHA
+  - JIV
+  - ECS
+  - QST
+  - RMB
+  - SFT
+  - SLS
+  - SMB
+  - SWZ
+  - TGO
+  - VWZ
+  - WCS
+
 data_module:
+  batch_size: 64
+  num_workers: 10
+  test_proportion: 0.2
+
+datasets:
+  preprocessing.dataset.BestBallroomDataset:
+    audio_dir: data/ballroom-songs
+    class_list: *dance_ids
+    audio_window_jitter: 0.7
+
+  preprocessing.dataset.Music4DanceDataset:
+    song_data_path: data/songs_cleaned.csv
+    song_audio_path: data/samples # data/samples
+    class_list: *dance_ids
+    multi_label: False
+    min_votes: 1
+    audio_window_jitter: 0.7
+
+model:
+  n_channels: 128
+
+feature_extractor:
+  mask_count: 0 # Don't mask the data
+  snr_mean: 15.0 # Pretty much eliminate the noise
+  freq_mask_size: 10
+  time_mask_size: 80
 
 trainer:
   log_every_n_steps: 15
   accelerator: gpu
   max_epochs: 50
-  min_epochs:
+  min_epochs: 7
   fast_dev_run: False
   # gradient_clip_val: 0.5
   # overfit_batches: 1
+
 training_environment:
   learning_rate: 0.00053
-  n_channels: 128
+  log_spectrograms: False
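
The new training_fn key names a module-qualified training function. train.py shrank from 176 lines to a handful in this commit but is not shown here, so the dispatch sketch below is only a guess at the idea, not the committed code; the "models." package prefix in particular is an assumption.

import importlib

import yaml


def run(config_path: str):
    # Resolve e.g. "audio_spectrogram_transformer.train_lightning_ast"
    # and call it with the full config dict.
    with open(config_path) as f:
        config = yaml.safe_load(f)
    module_name, fn_name = config["training_fn"].rsplit(".", 1)
    train_fn = getattr(importlib.import_module(f"models.{module_name}"), fn_name)
    return train_fn(config)


if __name__ == "__main__":
    run("models/config/train_local.yaml")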
models/decision_tree.py
CHANGED
@@ -1,3 +1,4 @@
+import pytorch_lightning as pl
 from sklearn.base import ClassifierMixin, BaseEstimator
 import pandas as pd
 from torch import nn
@@ -5,8 +6,14 @@ import torch
 from typing import Iterator
 import numpy as np
 import json
+from torch.utils.data import random_split
 from tqdm import tqdm
 import librosa
+from joblib import dump, load
+from os import path
+import os
+
+from preprocessing.dataset import get_music4dance_examples
 
 DANCE_INFO_FILE = "data/dance_info.csv"
 dance_info_df = pd.read_csv(
@@ -24,9 +31,8 @@ class DanceTreeClassifier(BaseEstimator, ClassifierMixin):
     - BPM
     """
 
-    def __init__(self, device="cpu", lr=1e-4,
+    def __init__(self, device="cpu", lr=1e-4, verbose=True) -> None:
         self.device = device
-        self.epochs = epochs
         self.verbose = verbose
         self.lr = lr
         self.classifiers = {}
@@ -44,41 +50,40 @@ class DanceTreeClassifier(BaseEstimator, ClassifierMixin):
         x: (specs, bpms). The first element is the spectrogram, second element is the bpm. spec shape should be (channel, freq_bins, sr * time)
         y: (batch_size, n_classes)
         """
-                self.optimizers[dance] = torch.optim.Adam(
-                    classifier.parameters(), lr=self.lr
-                )
-            models = [
-                (dance, model, self.optimizers[dance])
-                for dance, model in self.classifiers.items()
-                if dance in matching_dances
-            ]
-            for model_i, (dance, model, opt) in enumerate(models):
-                opt.zero_grad()
-                output = model(spec)
-                target = torch.tensor([float(dance == label)], device=self.device)
-                loss = self.criterion(output, target)
-                epoch_loss += loss.item()
-                pred_count += 1
-                loss.backward()
-                opt.step()
-                progress_bar.set_description(
-                    f"Loss: {epoch_loss / pred_count}, Step: {step}, Model: {model_i+1}/{len(models)}"
-                )
+        epoch_loss = 0
+        pred_count = 0
+        data_loader = zip(x, y)
+        if self.verbose:
+            data_loader = tqdm(data_loader, total=len(y))
+        for (spec, bpm), label in data_loader:
+            # find all models that are in the bpm range
+            matching_dances = self.get_valid_dances_from_bpm(bpm)
+            spec = torch.from_numpy(spec).to(self.device)
+            for dance in matching_dances:
+                if dance not in self.classifiers or dance not in self.optimizers:
+                    classifier = DanceCNN().to(self.device)
+                    self.classifiers[dance] = classifier
+                    self.optimizers[dance] = torch.optim.Adam(
+                        classifier.parameters(), lr=self.lr
+                    )
+            models = [
+                (dance, model, self.optimizers[dance])
+                for dance, model in self.classifiers.items()
+                if dance in matching_dances
+            ]
+            for model_i, (dance, model, opt) in enumerate(models, start=1):
+                opt.zero_grad()
+                output = model(spec)
+                target = torch.tensor([float(dance == label)], device=self.device)
+                loss = self.criterion(output, target)
+                epoch_loss += loss.item()
+                pred_count += 1
+                loss.backward()
+                if self.verbose:
+                    data_loader.set_description(
+                        f"model: {model_i}/{len(models)}, loss: {loss.item()}"
+                    )
+                opt.step()
 
     def predict(self, x) -> list[str]:
         results = []
@@ -90,6 +95,52 @@ class DanceTreeClassifier(BaseEstimator, ClassifierMixin):
             results.append(matching_dances[dance_i])
         return results
 
+    def save(self, folder: str):
+        # Create a folder
+        classifier_path = path.join(folder, "classifier")
+        os.makedirs(classifier_path, exist_ok=True)
+
+        # Swap out model reference
+        classifiers = self.classifiers
+        optimizers = self.optimizers
+        criterion = self.criterion
+
+        self.classifiers = None
+        self.optimizers = None
+        self.criterion = None
+
+        # Save the Pth models
+        for dance, classifier in classifiers.items():
+            torch.save(
+                classifier.state_dict(), path.join(classifier_path, dance + ".pth")
+            )
+
+        # Save the Sklearn model
+        dump(path.join(folder, "sklearn.joblib"))
+
+        # Reload values
+        self.classifiers = classifiers
+        self.optimizers = optimizers
+        self.criterion = criterion
+
+    @staticmethod
+    def from_config(folder: str, device="cpu") -> "DanceTreeClassifier":
+        # load in weights
+        model_paths = (
+            p for p in os.listdir(path.join(folder, "classifier")) if p.endswith("pth")
+        )
+        classifiers = {}
+        for model_path in model_paths:
+            dance = model_path.split(".")[0]
+            model = DanceCNN().to(device)
+            model.load_state_dict(
+                torch.load(path.join(folder, "classifier", model_path))
+            )
+            classifiers[dance] = model
+        wrapper = load(path.join(folder, "sklearn.joblib"))
+        wrapper.classifiers = classifiers
+        return wrapper
+
 
 class DanceCNN(nn.Module):
     def __init__(self, sr=16000, freq_bins=20, duration=6, *args, **kwargs) -> None:
@@ -136,7 +187,6 @@ def features_from_path(
     num_frames = audio_window_duration * sr
     tempo, _ = librosa.beat.beat_track(y=waveform, sr=sr)
     spec = librosa.feature.melspectrogram(y=waveform, sr=sr)
-    mfccs = librosa.feature.mfcc(y=waveform, sr=sr, n_mfcc=20)
     spec_normalized = (spec - spec.mean()) / spec.std()
     spec_padded = librosa.util.fix_length(
         spec_normalized, size=sr * audio_duration, axis=1
@@ -145,3 +195,40 @@ def features_from_path(
     for i in range(audio_duration // audio_window_duration):
         spec_window = batched_spec[:, :, i * num_frames : (i + 1) * num_frames]
         yield (spec_window, tempo)
+
+
+def train_decision_tree(config: dict):
+    TARGET_CLASSES = config["global"]["dance_ids"]
+    DEVICE = config["global"]["device"]
+    SEED = config["global"]["seed"]
+    SEED = config["global"]["seed"]
+    EPOCHS = config["trainer"]["min_epochs"]
+    song_data_path = config["data_module"]["song_data_path"]
+    song_audio_path = config["data_module"]["song_audio_path"]
+    pl.seed_everything(SEED, workers=True)
+
+    df = pd.read_csv(song_data_path)
+    x, y = get_music4dance_examples(
+        df, song_audio_path, class_list=TARGET_CLASSES, multi_label=True
+    )
+    # Convert y back to string classes
+    y = np.array(TARGET_CLASSES)[y.argmax(-1)]
+    train_i, test_i = random_split(
+        np.arange(len(x)), [0.1, 0.9]
+    )  # Temporary to test efficacy
+    train_paths, train_y = x[train_i], y[train_i]
+    model = DanceTreeClassifier(device=DEVICE)
+    for epoch in tqdm(range(1, EPOCHS + 1)):
+        # Shuffle the data
+        i = np.arange(len(train_paths))
+        np.random.shuffle(i)
+        train_paths = train_paths[i]
+        train_y = train_y[i]
+        train_x = features_from_path(train_paths)
+        model.fit(train_x, train_y)
+
+    # evaluate the model
+    preds = model.predict(x[test_i])
+    accuracy = (preds == y[test_i]).mean()
+    print(f"{accuracy=}")
+    model.save("models/weights/decision_tree")
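
A hypothetical inference sketch for the saved decision tree. It reuses the folder written by model.save above; the input format for predict() is inferred from the evaluation block (an array of audio paths), which the diff does not document explicitly, so treat it as an assumption.

import numpy as np

from models.decision_tree import DanceTreeClassifier

model = DanceTreeClassifier.from_config("models/weights/decision_tree", device="cpu")
audio_paths = np.array(["data/samples/example_song.wav"])  # hypothetical file
print(model.predict(audio_paths))  # e.g. ["CHA"]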
models/residual.py
CHANGED
@@ -1,18 +1,25 @@
+import pytorch_lightning as pl
+from pytorch_lightning import callbacks as cb
 import torch
+from torch import nn
 import torch.nn as nn
 import torch.nn.functional as F
 
-import pytorch_lightning as pl
 import numpy as np
 import torchaudio
 import yaml
-from .
-from preprocessing.
+from models.training_environment import TrainingEnvironment
+from preprocessing.dataset import DanceDataModule, get_datasets
+from preprocessing.pipelines import (
+    SpectrogramTrainingPipeline,
+    WaveformPreprocessing,
+)
 
 # Architecture based on: https://github.com/minzwon/sota-music-tagging-models/blob/36aa13b7205ff156cf4dcab60fd69957da453151/training/model.py
 
+
 class ResidualDancer(nn.Module):
-    def __init__(self,n_channels=128, n_classes=50):
+    def __init__(self, n_channels=128, n_classes=50):
         super().__init__()
 
         self.n_channels = n_channels
@@ -25,17 +32,17 @@ class ResidualDancer(nn.Module):
         self.res_layers = nn.Sequential(
             ResBlock(1, n_channels, stride=2),
             ResBlock(n_channels, n_channels, stride=2),
-            ResBlock(n_channels, n_channels*2, stride=2),
-            ResBlock(n_channels*2, n_channels*2, stride=2),
-            ResBlock(n_channels*2, n_channels*2, stride=2),
-            ResBlock(n_channels*2, n_channels*2, stride=2),
-            ResBlock(n_channels*2, n_channels*4, stride=2)
+            ResBlock(n_channels, n_channels * 2, stride=2),
+            ResBlock(n_channels * 2, n_channels * 2, stride=2),
+            ResBlock(n_channels * 2, n_channels * 2, stride=2),
+            ResBlock(n_channels * 2, n_channels * 2, stride=2),
+            ResBlock(n_channels * 2, n_channels * 4, stride=2),
         )
 
         # Dense
-        self.dense1 = nn.Linear(n_channels*4, n_channels*4)
-        self.bn = nn.BatchNorm1d(n_channels*4)
-        self.dense2 = nn.Linear(n_channels*4, n_classes)
+        self.dense1 = nn.Linear(n_channels * 4, n_channels * 4)
+        self.bn = nn.BatchNorm1d(n_channels * 4)
+        self.dense2 = nn.Linear(n_channels * 4, n_classes)
         self.dropout = nn.Dropout(0.2)
 
     def forward(self, x):
@@ -56,24 +63,34 @@ class ResidualDancer(nn.Module):
         x = F.relu(x)
         x = self.dropout(x)
         x = self.dense2(x)
-        x = nn.Sigmoid()(x)
+        # x = nn.Sigmoid()(x)
 
         return x
 
 
 class ResBlock(nn.Module):
     def __init__(self, input_channels, output_channels, shape=3, stride=2):
         super().__init__()
         # convolution
+        self.conv_1 = nn.Conv2d(
+            input_channels, output_channels, shape, stride=stride, padding=shape // 2
+        )
         self.bn_1 = nn.BatchNorm2d(output_channels)
+        self.conv_2 = nn.Conv2d(
+            output_channels, output_channels, shape, padding=shape // 2
+        )
         self.bn_2 = nn.BatchNorm2d(output_channels)
 
         # residual
         self.diff = False
         if (stride != 1) or (input_channels != output_channels):
+            self.conv_3 = nn.Conv2d(
+                input_channels,
+                output_channels,
+                shape,
+                stride=stride,
+                padding=shape // 2,
+            )
             self.bn_3 = nn.BatchNorm2d(output_channels)
             self.diff = True
         self.relu = nn.ReLU()
@@ -89,79 +106,31 @@ class ResBlock(nn.Module):
         out = self.relu(out)
         return out
 
-class TrainingEnvironment(pl.LightningModule):
-
-    def __init__(self, model: nn.Module, criterion: nn.Module, config:dict, learning_rate=1e-4, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.model = model
-        self.criterion = criterion
-        self.learning_rate = learning_rate
-        self.config=config
-        self.save_hyperparameters({
-            "model": type(model).__name__,
-            "loss": type(criterion).__name__,
-            "config": config,
-            **kwargs
-        })
-
-    def training_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int) -> torch.Tensor:
-        features, labels = batch
-        outputs = self.model(features)
-        loss = self.criterion(outputs, labels)
-        metrics = calculate_metrics(outputs, labels, prefix="train/", multi_label=True)
-        self.log_dict(metrics, prog_bar=True)
-        # Log spectrograms
-        if batch_index % 100 == 0:
-            tensorboard = self.logger.experiment
-            img_index = torch.randint(0, len(features), (1,)).item()
-            img = features[img_index][0]
-            img = (img - img.min()) / (img.max() - img.min())
-            tensorboard.add_image(f"batch: {batch_index}, element: {img_index}", img, 0, dataformats='HW')
-        return loss
-
-    def validation_step(self, batch:tuple[torch.Tensor, torch.TensorType], batch_index:int):
-        x, y = batch
-        preds = self.model(x)
-        metrics = calculate_metrics(preds, y, prefix="val/", multi_label=True)
-        metrics["val/loss"] = self.criterion(preds, y)
-        self.log_dict(metrics,prog_bar=True)
-
-    def test_step(self, batch:tuple[torch.Tensor, torch.TensorType], batch_index:int):
-        x, y = batch
-        preds = self.model(x)
-        self.log_dict(calculate_metrics(preds, y, prefix="test/", multi_label=True), prog_bar=True)
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
-        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') {"scheduler": scheduler, "monitor": "val/loss"}
-        return [optimizer]
-
 
 class DancePredictor:
     def __init__(
-        self,
-        weight_path:str,
-        labels:list[str],
-        expected_duration=6,
+        self,
+        weight_path: str,
+        labels: list[str],
+        expected_duration=6,
         threshold=0.5,
         resample_frequency=16000,
-        device="cpu"
+        device="cpu",
+    ):
         super().__init__()
+
         self.expected_duration = expected_duration
         self.threshold = threshold
         self.resample_frequency = resample_frequency
+        self.preprocess_waveform = WaveformPreprocessing(
+            resample_frequency * expected_duration
+        )
+        self.audio_to_spectrogram = lambda x: x  # TODO: Fix
        self.labels = np.array(labels)
         self.device = device
         self.model = self.get_model(weight_path)
 
-    def get_model(self, weight_path:str) -> nn.Module:
+    def get_model(self, weight_path: str) -> nn.Module:
@@ -170,21 +139,25 @@ class DancePredictor:
         weights = torch.load(weight_path, map_location=self.device)["state_dict"]
         model = ResidualDancer(n_classes=len(self.labels))
         for key in list(weights):
         return model.to(self.device).eval()
 
     @classmethod
-    def from_config(cls, config_path:str) -> "DancePredictor":
+    def from_config(cls, config_path: str) -> "DancePredictor":
         with open(config_path, "r") as f:
             config = yaml.safe_load(f)
         return DancePredictor(**config)
 
     @torch.no_grad()
-    def __call__(self, waveform: np.ndarray, sample_rate:int) -> dict[str,float]:
+    def __call__(self, waveform: np.ndarray, sample_rate: int) -> dict[str, float]:
         if len(waveform.shape) > 1 and waveform.shape[1] < waveform.shape[0]:
-            waveform = waveform.transpose(1,0)
+            waveform = waveform.transpose(1, 0)
         elif len(waveform.shape) == 1:
             waveform = np.expand_dims(waveform, 0)
         waveform = torch.from_numpy(waveform.astype("int16"))
+        waveform = torchaudio.functional.apply_codec(
+            waveform, sample_rate, "wav", channels_first=True
+        )
 
+        waveform = torchaudio.functional.resample(
+            waveform, sample_rate, self.resample_frequency
+        )
         waveform = self.preprocess_waveform(waveform)
         spectrogram = self.audio_to_spectrogram(waveform)
         spectrogram = spectrogram.unsqueeze(0).to(self.device)
@@ -194,8 +167,31 @@ class DancePredictor:
         result_mask = results > self.threshold
         probs = results[result_mask]
         dances = self.labels[result_mask]
 
         return {dance: float(prob) for dance, prob in zip(dances, probs)}
+
+
+def train_residual_dancer(config: dict):
+    TARGET_CLASSES = config["dance_ids"]
+    DEVICE = config["device"]
+    SEED = config["seed"]
+    pl.seed_everything(SEED, workers=True)
+    feature_extractor = SpectrogramTrainingPipeline(**config["feature_extractor"])
+    dataset = get_datasets(config["datasets"], feature_extractor)
+
+    data = DanceDataModule(dataset, **config["data_module"])
+    model = ResidualDancer(n_classes=len(TARGET_CLASSES), **config["model"])
+    label_weights = data.get_label_weights().to(DEVICE)
+    criterion = nn.CrossEntropyLoss(label_weights)
+
+    train_env = TrainingEnvironment(model, criterion, config)
+    callbacks = [
+        # cb.LearningRateFinder(update_attr=True),
+        cb.EarlyStopping("val/loss", patience=5),
+        cb.StochasticWeightAveraging(1e-2),
+        cb.RichProgressBar(),
+        cb.DeviceStatsMonitor(),
+    ]
+    trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
+    trainer.fit(train_env, datamodule=data)
+    trainer.test(train_env, datamodule=data)
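
A hypothetical inference sketch for DancePredictor: from_config reads a YAML file whose keys are the constructor arguments (weight_path, labels, ...), and the callable returns a dance-to-probability dict for every class above the threshold. The config path and audio file below are placeholders, not files from the repo.

import torchaudio

from models.residual import DancePredictor

predictor = DancePredictor.from_config("models/config/dance_predictor.yaml")  # assumed path
waveform, sample_rate = torchaudio.load("data/samples/example_song.wav")  # hypothetical file
scores = predictor(waveform.numpy(), sample_rate)
print(scores)  # e.g. {"SWZ": 0.81, "VWZ": 0.64}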
models/training_environment.py
ADDED
@@ -0,0 +1,90 @@
from models.utils import calculate_metrics


import pytorch_lightning as pl
import torch
import torch.nn as nn


class TrainingEnvironment(pl.LightningModule):
    def __init__(
        self,
        model: nn.Module,
        criterion: nn.Module,
        config: dict,
        learning_rate=1e-4,
        log_spectrograms=False,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.model = model
        self.criterion = criterion
        self.learning_rate = learning_rate
        self.log_spectrograms = log_spectrograms
        self.config = config
        self.has_multi_label_predictions = (
            not type(criterion).__name__ == "CrossEntropyLoss"
        )
        self.save_hyperparameters(
            {
                "model": type(model).__name__,
                "loss": type(criterion).__name__,
                "config": config,
                **kwargs,
            }
        )

    def training_step(
        self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int
    ) -> torch.Tensor:
        features, labels = batch
        outputs = self.model(features)
        loss = self.criterion(outputs, labels)
        metrics = calculate_metrics(
            outputs,
            labels,
            prefix="train/",
            multi_label=self.has_multi_label_predictions,
        )
        self.log_dict(metrics, prog_bar=True)
        # Log spectrograms
        if self.log_spectrograms and batch_index % 100 == 0:
            tensorboard = self.logger.experiment
            img_index = torch.randint(0, len(features), (1,)).item()
            img = features[img_index][0]
            img = (img - img.min()) / (img.max() - img.min())
            tensorboard.add_image(
                f"batch: {batch_index}, element: {img_index}", img, 0, dataformats="HW"
            )
        return loss

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int
    ):
        x, y = batch
        preds = self.model(x)
        metrics = calculate_metrics(
            preds, y, prefix="val/", multi_label=self.has_multi_label_predictions
        )
        metrics["val/loss"] = self.criterion(preds, y)
        self.log_dict(metrics, prog_bar=True)

    def test_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
        x, y = batch
        preds = self.model(x)
        self.log_dict(
            calculate_metrics(
                preds, y, prefix="test/", multi_label=self.has_multi_label_predictions
            ),
            prog_bar=True,
        )

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val/loss",
        }
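
A toy sketch (not from the repo) of how the extracted TrainingEnvironment wraps an arbitrary nn.Module for a plain PyTorch Lightning run. The random tensors stand in for DanceDataModule; with CrossEntropyLoss the environment computes single-label metrics, so the targets are one-hot floats.

import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from models.training_environment import TrainingEnvironment

n_samples, n_features, n_classes = 256, 40, 14
features = torch.randn(n_samples, n_features)
labels = torch.eye(n_classes)[torch.randint(0, n_classes, (n_samples,))]  # one-hot floats

env = TrainingEnvironment(
    model=nn.Linear(n_features, n_classes),
    criterion=nn.CrossEntropyLoss(),
    config={},
    learning_rate=1e-3,
)
loader = DataLoader(TensorDataset(features, labels), batch_size=32)
trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(env, train_dataloaders=loader, val_dataloaders=loader)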
models/utils.py
CHANGED
@@ -1,14 +1,20 @@
 import torch.nn as nn
 import torch
 import numpy as np
+import evaluate
 from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
 
+
+accuracy = evaluate.load("accuracy")
+
+
 class LabelWeightedBCELoss(nn.Module):
     """
     Binary Cross Entropy loss that assumes each float in the final dimension is a binary probability distribution.
     Allows for the weighing of each probability distribution wrt loss.
     """
+
+    def __init__(self, label_weights: torch.Tensor, reduction="mean"):
         super().__init__()
         self.label_weights = label_weights
 
@@ -17,46 +23,67 @@ class LabelWeightedBCELoss(nn.Module):
                 self.reduction = torch.mean
             case "sum":
                 self.reduction = torch.sum
+
-    def _log(self,x:torch.Tensor) -> torch.Tensor:
+    def _log(self, x: torch.Tensor) -> torch.Tensor:
         return torch.clamp_min(torch.log(x), -100)
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-        losses = -self.label_weights * (
+        losses = -self.label_weights * (
+            target * self._log(input) + (1 - target) * self._log(1 - input)
+        )
         return self.reduction(losses)
 
 
 # TODO: Code a onehot
 
 
-def calculate_metrics(
+def calculate_metrics(
+    pred, target, threshold=0.5, prefix="", multi_label=True
+) -> dict[str, torch.Tensor]:
     target = target.detach().cpu().numpy()
     pred = pred.detach().cpu().numpy()
     params = {
+        "y_true": target if multi_label else target.argmax(1),
+        "y_pred": np.array(pred > threshold, dtype=float)
+        if multi_label
+        else pred.argmax(1),
+        "zero_division": 0,
+        "average": "macro",
+    }
+    metrics = {
+        "precision": precision_score(**params),
+        "recall": recall_score(**params),
+        "f1": f1_score(**params),
+        "accuracy": accuracy_score(y_true=params["y_true"], y_pred=params["y_pred"]),
+    }
+    return {
+        prefix + k: torch.tensor(v, dtype=torch.float32) for k, v in metrics.items()
+    }
 
 
 class EarlyStopping:
     def __init__(self, patience=0):
         self.patience = patience
         self.last_measure = np.inf
         self.consecutive_increase = 0
+
     def step(self, val) -> bool:
         if self.last_measure <= val:
-            self.consecutive_increase +=1
+            self.consecutive_increase += 1
         else:
             self.consecutive_increase = 0
         self.last_measure = val
 
         return self.patience < self.consecutive_increase
+
+
+def get_id_label_mapping(labels: list[str]) -> tuple[dict, dict]:
+    id2label = {str(i): label for i, label in enumerate(labels)}
+    label2id = {label: str(i) for i, label in enumerate(labels)}
+
+    return id2label, label2id
+
+
+def compute_hf_metrics(eval_pred):
+    predictions = np.argmax(eval_pred.predictions, axis=1)
+    return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)
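
A quick illustrative check (not in the repo) of the helpers above: get_id_label_mapping builds the id/label dicts the Hugging Face models expect, and calculate_metrics turns logits plus one-hot targets into prefixed metric tensors.

import torch

from models.utils import calculate_metrics, get_id_label_mapping

id2label, label2id = get_id_label_mapping(["CHA", "JIV", "SWZ"])
assert id2label["0"] == "CHA" and label2id["SWZ"] == "2"

preds = torch.tensor([[2.0, 0.1, -1.0], [0.2, 1.5, 0.3]])   # logits for two samples
targets = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # one-hot labels
metrics = calculate_metrics(preds, targets, prefix="val/", multi_label=False)
print(metrics["val/accuracy"])  # tensor(1.)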
models/wav2vec2.py
ADDED
@@ -0,0 +1,84 @@
import os
from typing import Any
import pytorch_lightning as pl
from torch.utils.data import random_split
from transformers import AutoFeatureExtractor
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer

from preprocessing.dataset import (
    HuggingFaceDatasetWrapper,
    BestBallroomDataset,
    get_datasets,
)
from preprocessing.pipelines import WaveformTrainingPipeline

from .utils import get_id_label_mapping, compute_hf_metrics

MODEL_CHECKPOINT = "facebook/wav2vec2-base"


class Wav2VecFeatureExtractor:
    def __init__(self) -> None:
        self.waveform_pipeline = WaveformTrainingPipeline()
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            MODEL_CHECKPOINT,
        )

    def __call__(self, waveform) -> Any:
        waveform = self.waveform_pipeline(waveform)
        return self.feature_extractor(
            waveform, sampling_rate=self.feature_extractor.sampling_rate
        )

    def __getattr__(self, attr):
        return getattr(self.feature_extractor, attr)


def train_wav_model(config: dict):
    TARGET_CLASSES = config["dance_ids"]
    DEVICE = config["device"]
    SEED = config["seed"]
    OUTPUT_DIR = "models/weights/wav2vec2"
    batch_size = config["data_module"]["batch_size"]
    epochs = config["trainer"]["min_epochs"]
    test_proportion = config["data_module"].get("test_proportion", 0.2)
    pl.seed_everything(SEED, workers=True)
    dataset = get_datasets(config["datasets"])
    id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
    test_proportion = config["data_module"]["test_proportion"]
    train_proporition = 1 - test_proportion
    train_ds, test_ds = random_split(dataset, [train_proporition, test_proportion])
    feature_extractor = Wav2VecFeatureExtractor()
    model = AutoModelForAudioClassification.from_pretrained(
        MODEL_CHECKPOINT,
        num_labels=len(TARGET_CLASSES),
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,
    ).to(DEVICE)
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=3e-5,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=5,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        push_to_hub=False,
        use_mps_device=DEVICE == "mps",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        tokenizer=feature_extractor,
        compute_metrics=compute_hf_metrics,
    )
    trainer.train()
    return model
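
A brief note on the wrapper above, with a tiny sketch: because Wav2VecFeatureExtractor defines __getattr__, any attribute it does not define itself falls through to the wrapped Hugging Face extractor, which lets the Trainer treat the wrapper like the real feature extractor. Running this downloads the facebook/wav2vec2-base checkpoint.

from models.wav2vec2 import Wav2VecFeatureExtractor

extractor = Wav2VecFeatureExtractor()
print(extractor.sampling_rate)  # delegated to the wrapped AutoFeatureExtractor (16000)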
preprocessing/dataset.py
CHANGED

@@ -1,15 +1,21 @@
+import importlib
+import os
+from typing import Any
 import torch
-from torch.utils.data import Dataset, DataLoader, random_split
+from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
 import numpy as np
 import pandas as pd
 import torchaudio as ta
-from .pipelines import AudioTrainingPipeline
 import pytorch_lightning as pl
+
+from preprocessing.preprocess import (
+    fix_dance_rating_counts,
+    get_unique_labels,
+    has_valid_audio,
+    url_to_filename,
+    vectorize_label_probs,
+    vectorize_multi_label,
+)


 class SongDataset(Dataset):
@@ -17,60 +23,67 @@ class SongDataset(Dataset):
         self,
         audio_paths: list[str],
         dance_labels: list[np.ndarray],
+        audio_start_offset=6,  # seconds
         audio_window_duration=6,  # seconds
-        audio_pipeline_kwargs={},
-        resample_frequency=16000,
+        audio_window_jitter=1.0,  # seconds
     ):
-        assert (
-            audio_duration % audio_window_duration == 0
-        ), "Audio window should divide duration evenly."
         assert (
             audio_window_duration > audio_window_jitter
         ), "Jitter should be a small fraction of the audio window duration."

         self.audio_paths = audio_paths
         self.dance_labels = dance_labels
+        audio_metadata = [ta.info(audio) for audio in audio_paths]
+        self.audio_durations = [
+            meta.num_frames / meta.sample_rate for meta in audio_metadata
+        ]
+        self.sample_rate = audio_metadata[0].sample_rate  # assuming same sample rate
         self.audio_window_duration = int(audio_window_duration)
+        self.audio_start_offset = audio_start_offset
         self.audio_window_jitter = audio_window_jitter
-        self.audio_duration = int(audio_duration)
-
-        self.audio_pipeline = AudioTrainingPipeline(
-            self.sample_rate,
-            resample_frequency,
-            audio_window_duration,
-            **audio_pipeline_kwargs,
-        )

     def __len__(self):
+        return int(
+            sum(
+                max(duration - self.audio_start_offset, 0) // self.audio_window_duration
+                for duration in self.audio_durations
+            )
+        )

     def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
-        spectrogram = self.audio_pipeline(waveform)
+        if isinstance(idx, list):
+            return [
+                (self._waveform_from_index(i), self._label_from_index(i)) for i in idx
+            ]

+        waveform = self._waveform_from_index(idx)
         dance_labels = self._label_from_index(idx)
+        return waveform, dance_labels
+
+    def _idx2audio_idx(self, idx: int) -> int:
+        return self._get_audio_loc_from_idx(idx)[0]
+
+    def _get_audio_loc_from_idx(self, idx: int) -> tuple[int, int]:
+        """
+        Converts dataset index to the indices that reference the target audio path
+        and window offset.
+        """
+        total_slices = 0
+        for audio_index, duration in enumerate(self.audio_durations):
+            audio_slices = max(
+                (duration - self.audio_start_offset) // self.audio_window_duration, 1
+            )
+            if total_slices + audio_slices > idx:
+                frame_index = idx - total_slices
+                return audio_index, frame_index
+            total_slices += audio_slices
+
+    def get_label_weights(self):
+        n_examples, n_classes = self.dance_labels.shape
+        return torch.from_numpy(n_examples / (n_classes * sum(self.dance_labels)))

     def _backtrace_audio_path(self, index: int) -> str:
+        return self.audio_paths[self._idx2audio_idx(index)]

     def _validate_output(self, x, y):
         is_finite = not torch.any(torch.isinf(x))
@@ -80,16 +93,18 @@ class SongDataset(Dataset):
         return all((is_finite, is_numerical, has_data, is_binary))

     def _waveform_from_index(self, idx: int) -> torch.Tensor:
+        audio_index, frame_index = self._get_audio_loc_from_idx(idx)
+        audio_filepath = self.audio_paths[audio_index]
+        num_windows = self.audio_durations[audio_index] // self.audio_window_duration
         jitter_start = -self.audio_window_jitter if frame_index > 0 else 0.0
         jitter_end = self.audio_window_jitter if frame_index != num_windows - 1 else 0.0
         jitter = int(
             torch.FloatTensor(1).uniform_(jitter_start, jitter_end) * self.sample_rate
         )
-        frame_offset = (
-            frame_index * self.audio_window_duration * self.sample_rate
+        frame_offset = int(
+            frame_index * self.audio_window_duration * self.sample_rate
+            + jitter
+            + self.audio_start_offset * self.sample_rate
         )
         num_frames = self.sample_rate * self.audio_window_duration
         waveform, sample_rate = ta.load(
@@ -101,41 +116,21 @@ class SongDataset(Dataset):
         return waveform

     def _label_from_index(self, idx: int) -> torch.Tensor:
+        return torch.from_numpy(self.dance_labels[self._idx2audio_idx(idx)])


+class HuggingFaceDatasetWrapper(Dataset):
     """
+    Makes a standard PyTorch Dataset compatible with a HuggingFace Trainer.
     """

+    def __init__(self, dataset, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.resampler = taT.Resample(self.sample_rate, self.resample_frequency)
+        self.dataset = dataset
         self.pipeline = []

     def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
-        assert (
-            waveform.shape[1] > 10
-        ), f"No data found: {self._backtrace_audio_path(idx)}"
-        # resample the waveform
-        waveform = self.resampler(waveform)
-
-        waveform = waveform.mean(0)
-
-        dance_labels = self._label_from_index(idx)
-        return waveform, dance_labels
-
-
-class HuggingFaceWaveformSongDataset(WaveformSongDataset):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.pipeline = []
-
-    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
-        x, y = super().__getitem__(idx)
+        x, y = self.dataset[idx]
         if len(self.pipeline) > 0:
            for fn in self.pipeline:
                 x = fn(x)
@@ -146,59 +141,158 @@
             "label": dance_labels,
         }

+    def __len__(self):
+        return len(self.dataset)
+
+    def append_to_pipeline(self, fn):
         """
+        Adds a preprocessing step to the dataset.
         """
         self.pipeline.append(fn)


+class BestBallroomDataset(Dataset):
+    def __init__(
+        self, audio_dir="data/ballroom-songs", class_list=None, **kwargs
+    ) -> None:
+        super().__init__()
+        song_paths, labels = self.get_examples(audio_dir, class_list)
+        self.song_dataset = SongDataset(song_paths, labels, **kwargs)
+
+    def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.song_dataset[index]
+
+    def __len__(self):
+        return len(self.song_dataset)
+
+    def get_examples(self, audio_dir, class_list=None):
+        dances = set(
+            f
+            for f in os.listdir(audio_dir)
+            if os.path.isdir(os.path.join(audio_dir, f))
+        )
+        common_dances = dances
+        if class_list is not None:
+            common_dances = dances & set(class_list)
+            dances = class_list
+        dances = np.array(sorted(dances))
+        song_paths = []
+        labels = []
+        for dance in common_dances:
+            dance_label = (dances == dance).astype("float32")
+            folder_path = os.path.join(audio_dir, dance)
+            folder_contents = [f for f in os.listdir(folder_path) if f.endswith(".wav")]
+            song_paths.extend(os.path.join(folder_path, f) for f in folder_contents)
+            labels.extend([dance_label] * len(folder_contents))
+
+        return np.array(song_paths), np.stack(labels)
+
+
+class Music4DanceDataset(Dataset):
+    def __init__(
+        self,
+        song_data_path,
+        song_audio_path,
+        class_list=None,
+        multi_label=True,
+        min_votes=1,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        df = pd.read_csv(song_data_path)
+        song_paths, labels = get_music4dance_examples(
+            df,
+            song_audio_path,
+            class_list=class_list,
+            multi_label=multi_label,
+            min_votes=min_votes,
+        )
+        self.song_dataset = SongDataset(song_paths, labels, **kwargs)
+
+    def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.song_dataset[index]
+
+    def __len__(self):
+        return len(self.song_dataset)
+
+
+def get_music4dance_examples(
+    df: pd.DataFrame, audio_dir: str, class_list=None, multi_label=True, min_votes=1
+) -> tuple[np.ndarray, np.ndarray]:
+    sampled_songs = df[has_valid_audio(df["Sample"], audio_dir)].copy(deep=True)
+    sampled_songs["DanceRating"] = fix_dance_rating_counts(sampled_songs["DanceRating"])
+    if class_list is not None:
+        class_list = set(class_list)
+        sampled_songs["DanceRating"] = sampled_songs["DanceRating"].apply(
+            lambda labels: {k: v for k, v in labels.items() if k in class_list}
+            if not pd.isna(labels)
+            and any(label in class_list and amt > 0 for label, amt in labels.items())
+            else np.nan
+        )
+    sampled_songs = sampled_songs.dropna(subset=["DanceRating"])
+    vote_mask = sampled_songs["DanceRating"].apply(
+        lambda dances: any(votes >= min_votes for votes in dances.values())
+    )
+    sampled_songs = sampled_songs[vote_mask]
+    labels = sampled_songs["DanceRating"].apply(
+        lambda dances: {
+            dance: votes for dance, votes in dances.items() if votes >= min_votes
+        }
+    )
+    unique_labels = np.array(get_unique_labels(labels))
+    vectorizer = vectorize_multi_label if multi_label else vectorize_label_probs
+    labels = labels.apply(lambda i: vectorizer(i, unique_labels))
+
+    audio_paths = [
+        os.path.join(audio_dir, url_to_filename(url)) for url in sampled_songs["Sample"]
+    ]
+
+    return np.array(audio_paths), np.stack(labels)
+
+
+class PipelinedDataset(Dataset):
+    """
+    Adds a feature extractor preprocessing step to a dataset.
+    """
+
+    def __init__(self, dataset, feature_extractor):
+        self._data = dataset
+        self.feature_extractor = feature_extractor
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, index):
+        sample, label = self._data[index]
+
+        features = self.feature_extractor(sample)
+        return features, label
+
+
 class DanceDataModule(pl.LightningDataModule):
     def __init__(
         self,
-        song_audio_path="data/samples",
+        dataset: Dataset,
         test_proportion=0.15,
         val_proportion=0.1,
         target_classes: list[str] = None,
-        min_votes=1,
         batch_size: int = 64,
         num_workers=10,
-        dataset_cls=None,
-        dataset_kwargs={},
     ):
         super().__init__()
-        self.song_data_path = song_data_path
-        self.song_audio_path = song_audio_path
         self.val_proportion = val_proportion
         self.test_proportion = test_proportion
         self.train_proportion = 1.0 - test_proportion - val_proportion
         self.target_classes = target_classes
         self.batch_size = batch_size
         self.num_workers = num_workers
-        self.dataset_cls = dataset_cls if dataset_cls is not None else SongDataset
-
-        df = pd.read_csv(song_data_path)
-        self.x, self.y = get_examples(
-            df,
-            self.song_audio_path,
-            class_list=self.target_classes,
-            multi_label=True,
-            min_votes=min_votes,
-        )
+        self.dataset = dataset

     def setup(self, stage: str):
+        self.train_ds, self.val_ds, self.test_ds = random_split(
+            self.dataset,
             [self.train_proportion, self.val_proportion, self.test_proportion],
         )
-        self.train_ds = self._dataset_from_indices(train_i)
-        self.val_ds = self._dataset_from_indices(val_i)
-        self.test_ds = self._dataset_from_indices(test_i)
-
-    def _dataset_from_indices(self, idx: list[int]) -> SongDataset:
-        return self.dataset_cls(self.x[idx], self.y[idx], **self.dataset_kwargs)

     def train_dataloader(self):
         return DataLoader(
@@ -210,110 +304,48 @@ class DanceDataModule(pl.LightningDataModule):

     def val_dataloader(self):
         return DataLoader(
             self.val_ds,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
         )

     def test_dataloader(self):
         return DataLoader(
             self.test_ds,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
         )

     def get_label_weights(self):
+        weights = [
+            ds.song_dataset.get_label_weights() for ds in self.dataset._data.datasets
+        ]
+        return torch.mean(torch.stack(weights), dim=0)  # TODO: Make this weighted

-    def preprocess_inputs(self, x):
-        device = x.device
-        x = list(x.squeeze(1).cpu().numpy())
-        x = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000)
-        return x["input_values"].to(device)
-
-    def training_step(
-        self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int
-    ) -> torch.Tensor:
-        features, labels = batch
-        features = self.preprocess_inputs(features)
-        outputs = self.model(features).logits
-        outputs = nn.Sigmoid()(
-            outputs
-        )  # good for multi label classification, should be softmax otherwise
-        loss = self.criterion(outputs, labels)
-        metrics = calculate_metrics(outputs, labels, prefix="train/", multi_label=True)
-        self.log_dict(metrics, prog_bar=True)
-        return loss
-
-    def validation_step(
-        self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int
-    ):
-        x, y = batch
-        x = self.preprocess_inputs(x)
-        preds = self.model(x).logits
-        preds = nn.Sigmoid()(preds)
-        metrics = calculate_metrics(preds, y, prefix="val/", multi_label=True)
-        metrics["val/loss"] = self.criterion(preds, y)
-        self.log_dict(metrics, prog_bar=True)
-
-    def test_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
-        x, y = batch
-        x = self.preprocess_inputs(x)
-        preds = self.model(x).logits
-        preds = nn.Sigmoid()(preds)
-        self.log_dict(
-            calculate_metrics(preds, y, prefix="test/", multi_label=True), prog_bar=True
-        )
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
-        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') {"scheduler": scheduler, "monitor": "val/loss"}
-        return [optimizer]
-
-
-def calculate_metrics(
-    pred, target, threshold=0.5, prefix="", multi_label=True
-) -> dict[str, torch.Tensor]:
-    target = target.detach().cpu().numpy()
-    pred = pred.detach().cpu().numpy()
-    params = {
-        "y_true": target if multi_label else target.argmax(1),
-        "y_pred": np.array(pred > threshold, dtype=float)
-        if multi_label
-        else pred.argmax(1),
-        "zero_division": 0,
-        "average": "macro",
-    }
-    metrics = {
-        "precision": precision_score(**params),
-        "recall": recall_score(**params),
-        "f1": f1_score(**params),
-        "accuracy": accuracy_score(y_true=params["y_true"], y_pred=params["y_pred"]),
-    }
-    return {
-        prefix + k: torch.tensor(v, dtype=torch.float32) for k, v in metrics.items()
-    }
+
+def find_mean_std(dataset: Dataset, zscore=1.96, moe=0.02, p=0.5):
+    """
+    Estimates the mean and standard deviations of the a dataset.
+    """
+    sample_size = int(np.ceil((zscore**2 * p * (1 - p)) / (moe**2)))
+    sample_indices = np.random.choice(
+        np.arange(len(dataset)), size=sample_size, replace=False
+    )
+    mean = 0
+    std = 0
+    for i in sample_indices:
+        features = dataset[i][0]
+        mean += features.mean().item()
+        std += features.std().item()
+    print("std", std / sample_size)
+    print("mean", mean / sample_size)
+
+
+def get_datasets(dataset_config: dict, feature_extractor) -> Dataset:
+    datasets = []
+    for dataset_path, kwargs in dataset_config.items():
+        module_name, class_name = dataset_path.rsplit(".", 1)
+        module = importlib.import_module(module_name)
+        ProvidedDataset = getattr(module, class_name)
+        datasets.append(ProvidedDataset(**kwargs))
+    return PipelinedDataset(ConcatDataset(datasets), feature_extractor)
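Note: the reworked SongDataset now derives its length from per-file durations rather than a fixed audio_duration. A small standalone sketch of the window bookkeeping follows, mirroring `__len__` and `_get_audio_loc_from_idx` with made-up durations; it is illustrative only and does not touch the real dataset classes.

# Two hypothetical files: 30 s and 20 s long, 6 s skipped at the start of each,
# 6 s training windows. Each file contributes floor((duration - offset) / window)
# examples, and a flat dataset index maps back to (file index, window index).
audio_durations = [30.0, 20.0]
audio_start_offset = 6
audio_window_duration = 6

total = int(
    sum(
        max(duration - audio_start_offset, 0) // audio_window_duration
        for duration in audio_durations
    )
)
print(total)  # 6: four windows from the first file, two from the second


def locate(idx: int) -> tuple[int, int]:
    # Same walk as _get_audio_loc_from_idx: advance through files until the
    # cumulative window count passes the requested index.
    total_slices = 0
    for audio_index, duration in enumerate(audio_durations):
        audio_slices = max(
            (duration - audio_start_offset) // audio_window_duration, 1
        )
        if total_slices + audio_slices > idx:
            return audio_index, idx - total_slices
        total_slices += audio_slices


print(locate(0))  # (0, 0): first window of the first file
print(locate(4))  # (1, 0.0): first window of the second file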
preprocessing/pipelines.py
CHANGED

@@ -3,29 +3,26 @@ import torchaudio
 from torchaudio import transforms as taT, functional as taF
 import torch.nn as nn

-class AudioTrainingPipeline(torch.nn.Module):
+
+class WaveformTrainingPipeline(torch.nn.Module):
+    def __init__(
+        self,
+        input_freq=16000,
+        resample_freq=16000,
+        expected_duration=6,
+        snr_mean=6.0,
+        noise_path=None,
+    ):
         super().__init__()
         self.input_freq = input_freq
         self.snr_mean = snr_mean
-        self.mask_count = mask_count
         self.noise = self.get_noise(noise_path)
+        self.resample_frequency = resample_freq
+        self.resample = taT.Resample(input_freq, resample_freq)
+
+        self.preprocess_waveform = WaveformPreprocessing(
+            resample_freq * expected_duration
         )
-        self.freq_mask = taT.FrequencyMasking(freq_mask_size)
-        self.time_mask = taT.TimeMasking(time_mask_size)

     def get_noise(self, path) -> torch.Tensor:
         if path is None:
@@ -34,13 +31,15 @@ class AudioTrainingPipeline(torch.nn.Module):
         if noise.shape[0] > 1:
             noise = noise.mean(0, keepdim=True)
         if sr != self.input_freq:
-            noise = taF.resample(noise,sr, self.input_freq)
+            noise = taF.resample(noise, sr, self.input_freq)
         return noise

-    def add_noise(self, waveform:torch.Tensor) -> torch.Tensor:
+    def add_noise(self, waveform: torch.Tensor) -> torch.Tensor:
+        assert (
+            self.noise is not None
+        ), "Cannot add noise because a noise file was not provided."
         num_repeats = waveform.shape[1] // self.noise.shape[1] + 1
-        noise = self.noise.repeat(1,num_repeats)[:, :waveform.shape[1]]
+        noise = self.noise.repeat(1, num_repeats)[:, : waveform.shape[1]]
         noise_power = noise.norm(p=2)
         signal_power = waveform.norm(p=2)
         snr_db = torch.normal(self.snr_mean, 1.5, (1,)).clamp_min(1.0)
@@ -49,14 +48,28 @@ class AudioTrainingPipeline(torch.nn.Module):
         noisy_waveform = (scale * waveform + noise) / 2
         return noisy_waveform

-    def forward(self, waveform:torch.Tensor) -> torch.Tensor:
-        try:
-            waveform = self.resample(waveform)
-        except:
-            print("oops")
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        waveform = self.resample(waveform)
         waveform = self.preprocess_waveform(waveform)
         if self.noise is not None:
             waveform = self.add_noise(waveform)
+        return waveform
+
+
+class SpectrogramTrainingPipeline(WaveformTrainingPipeline):
+    def __init__(
+        self, freq_mask_size=10, time_mask_size=80, mask_count=2, *args, **kwargs
+    ):
+        super().__init__(*args, **kwargs)
+        self.mask_count = mask_count
+        self.audio_to_spectrogram = AudioToSpectrogram(
+            sample_rate=self.resample_frequency,
+        )
+        self.freq_mask = taT.FrequencyMasking(freq_mask_size)
+        self.time_mask = taT.TimeMasking(time_mask_size)
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        waveform = super().forward(waveform)
         spec = self.audio_to_spectrogram(waveform)

         # Spectrogram augmentation
@@ -67,14 +80,11 @@


 class WaveformPreprocessing(torch.nn.Module):
-
-    def __init__(self, expected_sample_length:int):
+    def __init__(self, expected_sample_length: int):
         super().__init__()
         self.expected_sample_length = expected_sample_length
-

-
-    def forward(self, waveform:torch.Tensor) -> torch.Tensor:
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
         # Take out extra channels
         if waveform.shape[0] > 1:
             waveform = waveform.mean(0, keepdim=True)
@@ -83,30 +93,34 @@ class WaveformPreprocessing(torch.nn.Module):
         waveform = self._rectify_duration(waveform)
         return waveform

-
-    def _rectify_duration(self,waveform:torch.Tensor):
+    def _rectify_duration(self, waveform: torch.Tensor):
         expected_samples = self.expected_sample_length
         sample_count = waveform.shape[1]
         if expected_samples == sample_count:
             return waveform
         elif expected_samples > sample_count:
             pad_amount = expected_samples - sample_count
+            return torch.nn.functional.pad(
+                waveform, (0, pad_amount), mode="constant", value=0.0
+            )
         else:
+            return waveform[:, :expected_samples]


+class AudioToSpectrogram:
     def __init__(
         self,
         sample_rate=16000,
     ):
+        self.spec = taT.MelSpectrogram(
+            sample_rate=sample_rate, n_mels=128, n_fft=1024
+        )  # Note: this doesn't work on mps right now.
         self.to_db = taT.AmplitudeToDB()

+    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
         spectrogram = self.spec(waveform)
         spectrogram = self.to_db(spectrogram)
+
+        # Normalize
+        spectrogram = (spectrogram - spectrogram.mean()) / (2 * spectrogram.std())
+        return spectrogram
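Note: the refactor splits the old AudioTrainingPipeline into a waveform-only stage and a spectrogram stage that inherits from it. A quick smoke-test sketch follows; it feeds synthetic noise instead of a real song file (the in-repo test uses BestBallroomDataset), and it assumes the hidden augmentation step keeps the (channel, mel, time) layout, which the new test_pipelines.py asserts.

# Synthetic 6 s mono "waveform" standing in for real audio.
import torch
from preprocessing.pipelines import (
    WaveformTrainingPipeline,
    SpectrogramTrainingPipeline,
)

waveform = torch.randn(1, 16000 * 6)

wave_pipeline = WaveformTrainingPipeline()    # resample + channel/duration cleanup only
spec_pipeline = SpectrogramTrainingPipeline() # adds mel spectrogram + masking on top

print(wave_pipeline(waveform).shape)  # torch.Size([1, 96000])
print(spec_pipeline(waveform).shape)  # (1, 128, n_frames): a 3-D spectrogram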
preprocessing/preprocess.py
CHANGED

@@ -3,7 +3,9 @@ import numpy as np
 import re
 import json
 from pathlib import Path
+import glob
 import os
+import shutil
 import torchaudio
 import torch
 from tqdm import tqdm
@@ -95,7 +97,6 @@ def vectorize_label_probs(
     for k, v in labels.items():
         item_vec = (unique_labels == k) * v
         label_vec += item_vec
-    lv_cache = label_vec.copy()
     label_vec[label_vec < 0] = 0
     label_vec /= label_vec.sum()
     assert not any(np.isnan(label_vec)), f"Provided labels are invalid: {labels}"
@@ -113,49 +114,70 @@ def vectorize_multi_label(
     return probs


+def sort_yt_files(
+    aliases_path="data/dance_aliases.json",
+    all_dances_folder="data/best-ballroom-music",
+    original_location="data/yt-ballroom-music/",
+):
+    def normalize_string(s):
+        # Lowercase string and remove special characters
+        return re.sub(r"\W+", "", s.lower())
+
+    with open(aliases_path, "r") as f:
+        dances = json.load(f)
+
+    # Normalize the dance inputs and aliases
+    normalized_dances = {
+        normalize_string(dance_id): [normalize_string(alias) for alias in aliases]
+        for dance_id, aliases in dances.items()
+    }
+
+    # For every wav file in the target folder
+    bad_files = []
+    progress_bar = tqdm(os.listdir(all_dances_folder), unit="files moved")
+    for file_name in progress_bar:
+        if file_name.endswith(".wav"):
+            # check if the normalized wav file name contains the normalized dance alias
+            normalized_file_name = normalize_string(file_name)
+
+            matching_dance_ids = [
+                dance_id
+                for dance_id, aliases in normalized_dances.items()
+                if any(alias in normalized_file_name for alias in aliases)
+            ]
+
+            if len(matching_dance_ids) == 0:
+                # See if the dance is in the path
+                original_filename = file_name.replace(".wav", "")
+                matches = glob.glob(
+                    os.path.join(original_location, "**", original_filename),
+                    recursive=True,
+                )
+                if len(matches) == 1:
+                    normalized_file_name = normalize_string(matches[0])
+                    matching_dance_ids = [
+                        dance_id
+                        for dance_id, aliases in normalized_dances.items()
+                        if any(alias in normalized_file_name for alias in aliases)
+                    ]
+
+            if "swz" in matching_dance_ids and "vwz" in matching_dance_ids:
+                matching_dance_ids.remove("swz")
+            if len(matching_dance_ids) > 1 and "lhp" in matching_dance_ids:
+                matching_dance_ids.remove("lhp")
+
+            if len(matching_dance_ids) != 1:
+                bad_files.append(file_name)
+                progress_bar.set_description(f"bad files: {len(bad_files)}")
+                continue
+            dst = os.path.join("data", "ballroom-songs", matching_dance_ids[0].upper())
+            os.makedirs(dst, exist_ok=True)
+            filepath = os.path.join(all_dances_folder, file_name)
+            shutil.copy(filepath, os.path.join(dst, file_name))
+
+    with open("data/bad_files.json", "w") as f:
+        json.dump(bad_files, f)


 if __name__ == "__main__":
-    df = pd.read_csv("data/songs.csv")
-    l = links["link"].str.strip()
-    l = l.apply(lambda url: url if "http" in url else np.nan)
-    l = l.dropna()
-    df["Sample"].update(l)
-    addna = lambda url: url if type(url) == str and "http" in url else np.nan
-    df["Sample"] = df["Sample"].apply(addna)
-    is_valid = validate_audio(df["Sample"], "data/samples")
-    df["valid"] = is_valid
-    df.to_csv("data/songs_validated.csv")
+    sort_yt_files()
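Note: sort_yt_files expects data/dance_aliases.json to map a dance id to a list of name variants. That file is not part of this commit; the ids "swz", "vwz", and "lhp" appear in the disambiguation logic above, but the alias strings in this sketch are illustrative guesses, not the project's actual alias table.

import re

# Hypothetical alias map with the same shape the sorter loads from JSON.
aliases = {
    "swz": ["slow waltz", "waltz"],
    "vwz": ["viennese waltz"],
    "lhp": ["lindy hop"],
}


def normalize_string(s: str) -> str:
    # Same normalization the sorter applies to file names and aliases.
    return re.sub(r"\W+", "", s.lower())


file_name = "Viennese Waltz - Best Ballroom Mix.wav"
normalized = normalize_string(file_name)
matches = [
    dance_id
    for dance_id, names in aliases.items()
    if any(normalize_string(alias) in normalized for alias in names)
]
print(matches)  # ['swz', 'vwz']: both waltzes match, so the sorter drops "swz" and keeps "vwz"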
tests.py
DELETED

@@ -1,22 +0,0 @@
-import torchaudio
-import numpy as np
-from audio_utils import play_audio
-from preprocessing.dataset import SongDataset
-
-def test_audio_splitting():
-
-
-
-    audio_paths = ["data/samples/95f2df65f7450db3b1af29aa77ba7edc6ab52075?cid=7ffadeb2e136495fb5a62d1ac9be8f62.wav"]
-    labels = [np.array([1,0,1,0])]
-    whole_song, sr = torchaudio.load("data/samples/95f2df65f7450db3b1af29aa77ba7edc6ab52075?cid=7ffadeb2e136495fb5a62d1ac9be8f62.wav")
-
-    ds = SongDataset(audio_paths, labels)
-    song_parts = (ds._waveform_from_index(i) for i in range(len(ds)))
-    print("Sample Parts")
-    for part in song_parts:
-        play_audio(part,sr)
-
-
-    print("Whole Sample")
-    play_audio(whole_song,sr)
tests/test_datasets.py
ADDED

@@ -0,0 +1,17 @@
+from utils import set_path
+import pytest
+
+set_path()
+from preprocessing.dataset import PipelinedDataset, BestBallroomDataset, SongDataset
+import numpy as np
+
+
+def test_preprocess_dataset():
+    dataset = BestBallroomDataset()
+    dataset = PipelinedDataset(dataset, lambda x: x * 0.0)
+    assert isinstance(dataset._data.song_dataset, SongDataset)
+    assert hasattr(dataset, "feature_extractor")
+    features, _ = dataset[0]
+    assert np.unique(features.numpy())[0] == 0.0
+    with pytest.raises(AttributeError):
+        dataset.foo
tests/test_pipelines.py
ADDED

@@ -0,0 +1,13 @@
+from utils import set_path
+
+set_path()
+from preprocessing.dataset import BestBallroomDataset
+from preprocessing.pipelines import SpectrogramTrainingPipeline
+
+
+def test_spectrogram_training_pipeline():
+    ds = BestBallroomDataset()
+    pipeline = SpectrogramTrainingPipeline()
+    waveform, _ = ds[0]
+    out = pipeline(waveform)
+    assert len(out.shape) == 3
tests/utils.py
ADDED

@@ -0,0 +1,7 @@
+import sys
+import os
+
+
+# Add parent directory to Python path
+def set_path():
+    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
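Note: the new tests adjust sys.path themselves through tests/utils.set_path(), so they can be collected from the repository root. One way to launch them programmatically is sketched below; running plain pytest against the tests directory achieves the same thing.

import pytest

# Equivalent to running `pytest tests -q` from the project root.
exit_code = pytest.main(["tests", "-q"])
print(exit_code)  # 0 when all tests pass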
train.py
CHANGED

@@ -1,49 +1,16 @@
-from torch.utils.data import DataLoader
-import pandas as pd
 from typing import Callable
-
-from torch.utils.data import SubsetRandomSampler
-from sklearn.model_selection import KFold
-import pytorch_lightning as pl
-from pytorch_lightning import callbacks as cb
-from models.utils import LabelWeightedBCELoss
-from models.audio_spectrogram_transformer import (
-    train as train_audio_spectrogram_transformer,
-    get_id_label_mapping,
-)
-from preprocessing.dataset import SongDataset, WaveformTrainingEnvironment
-from preprocessing.preprocess import get_examples
-from models.residual import ResidualDancer, TrainingEnvironment
-from models.decision_tree import DanceTreeClassifier, features_from_path
+import importlib
 import yaml
-from preprocessing.dataset import (
-    DanceDataModule,
-    WaveformSongDataset,
-    HuggingFaceWaveformSongDataset,
-)
-from torch.utils.data import random_split
-import numpy as np
-from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
 from argparse import ArgumentParser
+import os

-
-import torch
-from torch import nn
-from sklearn.utils.class_weight import compute_class_weight
+ROOT_DIR = os.path.basename(os.path.dirname(__file__))


 def get_training_fn(id: str) -> Callable:
-    match id:
-        case "ast_hf":
-            return train_ast
-        case "residual_dancer":
-            return train_model
-        case "decision_tree":
-            return train_decision_tree
-        case _:
-            raise Exception(f"Couldn't find a training function for '{id}'.")
+    module_name, fn_name = id.rsplit(".", 1)
+    module = importlib.import_module("models." + module_name, ROOT_DIR)
+    return getattr(module, fn_name)


 def get_config(filepath: str) -> dict:
@@ -52,141 +19,6 @@ def get_config(filepath: str) -> dict:
     return config


-def cross_validation(config, k=5):
-    df = pd.read_csv("data/songs.csv")
-    g_config = config["global"]
-    batch_size = config["data_module"]["batch_size"]
-    x, y = get_examples(df, "data/samples", class_list=g_config["dance_ids"])
-    dataset = SongDataset(x, y)
-    splits = KFold(n_splits=k, shuffle=True, random_state=g_config["seed"])
-    trainer = pl.Trainer(accelerator=g_config["device"])
-    for fold, (train_idx, val_idx) in enumerate(splits.split(x, y)):
-        print(f"Fold {fold+1}")
-        model = ResidualDancer(n_classes=len(g_config["dance_ids"]))
-        train_env = TrainingEnvironment(model, nn.BCELoss())
-        train_sampler = SubsetRandomSampler(train_idx)
-        test_sampler = SubsetRandomSampler(val_idx)
-        train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
-        test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler)
-        trainer.fit(train_env, train_loader)
-        trainer.test(train_env, test_loader)
-
-
-def train_model(config: dict):
-    TARGET_CLASSES = config["global"]["dance_ids"]
-    DEVICE = config["global"]["device"]
-    SEED = config["global"]["seed"]
-    pl.seed_everything(SEED, workers=True)
-    data = DanceDataModule(target_classes=TARGET_CLASSES, **config["data_module"])
-    model = ResidualDancer(n_classes=len(TARGET_CLASSES), **config["model"])
-    label_weights = data.get_label_weights().to(DEVICE)
-    criterion = LabelWeightedBCELoss(
-        label_weights
-    )  # nn.CrossEntropyLoss(label_weights)
-    train_env = TrainingEnvironment(model, criterion, config)
-    callbacks = [
-        # cb.LearningRateFinder(update_attr=True),
-        cb.EarlyStopping("val/loss", patience=5),
-        cb.StochasticWeightAveraging(1e-2),
-        cb.RichProgressBar(),
-        cb.DeviceStatsMonitor(),
-    ]
-    trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
-    trainer.fit(train_env, datamodule=data)
-    trainer.test(train_env, datamodule=data)
-
-
-def train_ast(config: dict):
-    TARGET_CLASSES = config["global"]["dance_ids"]
-    DEVICE = config["global"]["device"]
-    SEED = config["global"]["seed"]
-    dataset_kwargs = config["data_module"]["dataset_kwargs"]
-    test_proportion = config["data_module"].get("test_proportion", 0.2)
-    train_proportion = 1.0 - test_proportion
-    song_data_path = "data/songs_cleaned.csv"
-    song_audio_path = "data/samples"
-    pl.seed_everything(SEED, workers=True)
-
-    df = pd.read_csv(song_data_path)
-    x, y = get_examples(
-        df, song_audio_path, class_list=TARGET_CLASSES, multi_label=True
-    )
-    train_i, test_i = random_split(
-        np.arange(len(x)), [train_proportion, test_proportion]
-    )
-    train_ds = HuggingFaceWaveformSongDataset(
-        x[train_i], y[train_i], **dataset_kwargs, resample_frequency=16000
-    )
-    test_ds = HuggingFaceWaveformSongDataset(
-        x[test_i], y[test_i], **dataset_kwargs, resample_frequency=16000
-    )
-    train_audio_spectrogram_transformer(
-        TARGET_CLASSES, train_ds, test_ds, device=DEVICE
-    )
-
-
-def train_ast_lightning(config: dict):
-    """
-    work on integration between waveform dataset and environment. Should work for both HF and PTL.
-    """
-    TARGET_CLASSES = config["global"]["dance_ids"]
-    DEVICE = config["global"]["device"]
-    SEED = config["global"]["seed"]
-    pl.seed_everything(SEED, workers=True)
-    data = DanceDataModule(
-        target_classes=TARGET_CLASSES,
-        dataset_cls=WaveformSongDataset,
-        **config["data_module"],
-    )
-    id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
-    model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
-    feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
-
-    model = AutoModelForAudioClassification.from_pretrained(
-        model_checkpoint,
-        num_labels=len(label2id),
-        label2id=label2id,
-        id2label=id2label,
-        ignore_mismatched_sizes=True,
-    ).to(DEVICE)
-    label_weights = data.get_label_weights().to(DEVICE)
-    criterion = LabelWeightedBCELoss(
-        label_weights
-    )  # nn.CrossEntropyLoss(label_weights)
-    train_env = WaveformTrainingEnvironment(model, criterion, feature_extractor, config)
-    callbacks = [
-        # cb.LearningRateFinder(update_attr=True),
-        cb.EarlyStopping("val/loss", patience=5),
-        cb.StochasticWeightAveraging(1e-2),
-        cb.RichProgressBar(),
-    ]
-    trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
-    trainer.fit(train_env, datamodule=data)
-    trainer.test(train_env, datamodule=data)
-
-
-def train_decision_tree(config: dict):
-    TARGET_CLASSES = config["global"]["dance_ids"]
-    DEVICE = config["global"]["device"]
-    SEED = config["global"]["seed"]
-    song_data_path = config["data_module"]["song_data_path"]
-    song_audio_path = config["data_module"]["song_audio_path"]
-    pl.seed_everything(SEED, workers=True)
-
-    df = pd.read_csv(song_data_path)
-    x, y = get_examples(
-        df, song_audio_path, class_list=TARGET_CLASSES, multi_label=True
-    )
-    # Convert y back to string classes
-    y = np.array(TARGET_CLASSES)[y.argmax(-1)]
-    train_i, test_i = random_split(np.arange(len(x)), [0.8, 0.2])
-    train_paths, train_y = x[train_i], y[train_i]
-    train_x = features_from_path(train_paths)
-    model = DanceTreeClassifier(device=DEVICE)
-    model.fit(train_x, train_y)
-    model.save()
-
-
 if __name__ == "__main__":
     parser = ArgumentParser(
         description="Trains models on the dance dataset and saves weights."
@@ -198,6 +30,7 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
     config = get_config(args.config)
+    training_fn_path = config["training_fn"]
+    print(f"Config: {args.config}\nTrainer Id: {training_fn_path}")
+    train = get_training_fn(training_fn_path)
     train(config)
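Note: train.py no longer hard-codes a match statement; it resolves the config's `training_fn` id to a function inside the models package at runtime. The sketch below mirrors that dispatch; the id "decision_tree.train_decision_tree" is only an illustrative example, since the actual training-function names live in the model modules touched by this commit and are not shown here.

import importlib
from typing import Callable


def get_training_fn(id: str) -> Callable:
    # "module.function" relative to the models package, e.g. from the YAML config's
    # training_fn field (the package-relative ROOT_DIR argument is omitted here).
    module_name, fn_name = id.rsplit(".", 1)
    module = importlib.import_module("models." + module_name)
    return getattr(module, fn_name)


# With a config like {"training_fn": "decision_tree.train_decision_tree", ...},
# train.py effectively runs:
#     train = get_training_fn(config["training_fn"])
#     train(config)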