| import os |
| import random |
|
|
| import torch |
| import torchaudio |
| import torchaudio.transforms as AT |
| import csv |
| import numpy as np |
| import librosa |
| import pandas as pd |
| import laion_clap |
| import soundfile as sf |
| from model.CLAPSep import LightningModule |
| from model.CLAPSep_decoder import HTSAT_Decoder |
| import argparse |
| import pytorch_lightning as pl |
| from helpers import utils as local_utils |
|
|
|
|
class AudioCapsTest(torch.utils.data.Dataset):
    """Evaluation dataset that mixes a target clip with a noise clip at 0 dB SNR.

    Each CSV row after the header holds, in order: target file name, target
    caption, noise file name, noise caption. File names are resolved relative
    to ``input_dir``.

    ``__getitem__`` returns ``(mixed, mixed_resampled, tgt_cap, neg_cap, tgt)``:
    ``mixed`` and ``tgt`` are 1-D waveforms at ``sr`` and ``mixed_resampled``
    is ``mixed`` passed through ``self.resampler`` (resampled to
    ``resample_rate``, or unchanged when ``resample_rate`` is None).
    """

    def __init__(self, eval_csv, input_dir, sr=32000,
                 resample_rate=48000, max_seconds=10):
        """
        Args:
            eval_csv: Path to the evaluation CSV; its first row is skipped as a header.
            input_dir: Directory containing the audio files named in the CSV.
            sr: Sample rate at which clips are loaded.
            resample_rate: Target rate for the second returned mixture. ``None``
                disables resampling (the mixture is returned unchanged).
            max_seconds: Every clip is truncated / zero-padded to this duration.

        Raises:
            ValueError: if a non-blank CSV row has fewer than four fields.
        """
        self.data_path = input_dir
        self.sr = sr
        self.max_seconds = max_seconds

        self.data_names = []
        self.data_caps = []
        self.noise_names = []
        self.noise_caps = []
        # newline='' is what the csv module requires for correct handling of
        # quoted fields containing line breaks.
        with open(eval_csv, 'r', newline='', encoding='utf-8') as d:
            reader = csv.reader(d, skipinitialspace=True)
            next(reader, None)  # skip the header; tolerate an empty file
            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines; skip them.
                    continue
                if len(row) < 4:
                    raise ValueError(
                        f"{eval_csv}: line {reader.line_num} has {len(row)} "
                        f"fields, expected at least 4")
                self.data_names.append(row[0])
                self.data_caps.append(row[1])
                self.noise_names.append(row[2])
                self.noise_caps.append(row[3])

        if resample_rate is not None:
            self.resampler = AT.Resample(sr, resample_rate)
            self.resample_rate = resample_rate
        else:
            # Without a resampler __getitem__ would fail with AttributeError;
            # pass the mixture through unchanged instead.
            self.resampler = torch.nn.Identity()
            self.resample_rate = sr

    def __len__(self):
        """Return the number of (target, noise) pairs read from the CSV."""
        return len(self.data_names)

    def load_wav(self, path):
        """Load ``path`` at ``self.sr`` and fix it to ``max_seconds`` of samples.

        Longer clips are truncated; shorter clips are zero-padded at the end.
        Returns the NumPy array produced by librosa (truncated or padded).
        """
        max_length = int(self.sr * self.max_seconds)
        wav = librosa.load(path, sr=self.sr)[0]
        if len(wav) >= max_length:
            return wav[:max_length]
        return np.pad(wav, (0, max_length - len(wav)), 'constant')

    def __getitem__(self, idx):
        """Build the ``idx``-th mixture.

        Raises:
            ValueError: if the target and noise rows name the same file
                (the "mixture" would then contain the target twice).
        """
        tgt_name = self.data_names[idx]
        noise_name = self.noise_names[idx]
        tgt_cap = self.data_caps[idx]
        neg_cap = self.noise_caps[idx]

        if noise_name == tgt_name:
            raise ValueError(
                f"target and noise clips must differ (row {idx}: {tgt_name!r})")

        snr = torch.zeros((1,))  # mix at 0 dB
        tgt = torch.tensor(self.load_wav(os.path.join(self.data_path, tgt_name))).unsqueeze(0)
        noise = torch.tensor(self.load_wav(os.path.join(self.data_path, noise_name))).unsqueeze(0)
        mixed = torchaudio.functional.add_noise(tgt, noise, snr=snr)

        # If the mixture clips, rescale mixture and target together so the
        # mixture peak becomes 0.9 while the target stays consistent with it.
        max_value = torch.max(torch.abs(mixed))
        if max_value > 1:
            tgt *= 0.9 / max_value
            mixed *= 0.9 / max_value

        tgt = tgt.squeeze()
        mixed = mixed.squeeze()

        return mixed, self.resampler(mixed), tgt_cap, neg_cap, tgt
|
|
|
|
|
|
def main(args):
    """Run ``trainer.test`` for a CLAPSep model on the mixture test set.

    ``args`` combines the CLI options with the keys loaded from
    ``<exp_dir>/config.json`` by the script entry point; this function reads
    ``model``, ``optim['lr']``, ``lora``, ``lora_rank``, ``nfft``,
    ``clap_path`` and ``epochs`` from it in addition to the CLI options.
    """
    torch.set_float32_matmul_precision('highest')

    data_test = AudioCapsTest(eval_csv=args.eval_csv,
                              input_dir=args.input_dir,
                              sr=args.sample_rate,
                              resample_rate=48000)

    test_loader = torch.utils.data.DataLoader(data_test,
                                              batch_size=1,
                                              num_workers=1,
                                              pin_memory=True,
                                              shuffle=False)

    clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cpu')
    clap_model.load_ckpt(args.clap_path)
    decoder = HTSAT_Decoder(**args.model)
    lightning_module = LightningModule(clap_model, decoder, lr=args.optim['lr'],
                                       use_lora=args.lora,
                                       rank=args.lora_rank,
                                       nfft=args.nfft)

    if args.use_cuda:
        # --gpu_ids is optional and documented as "all GPUs by default";
        # map a missing list to "auto" instead of forwarding None.
        devices = args.gpu_ids if args.gpu_ids else "auto"
        accelerator = "gpu"
    else:
        devices = "auto"
        accelerator = "cpu"

    # NOTE(review): with more than one device, DDP shards the test set with a
    # DistributedSampler, which may pad/duplicate samples and skew metrics.
    # Consider devices=1 for evaluation.
    distributed_backend = "ddp"
    trainer = pl.Trainer(
        default_root_dir=os.path.join(args.exp_dir, 'checkpoint'),
        devices=devices,
        accelerator=accelerator,
        benchmark=False,
        gradient_clip_val=5.0,
        precision='bf16-mixed',
        limit_train_batches=1.0,
        max_epochs=args.epochs,
        strategy=distributed_backend,
        logger=False
    )

    # The CLI default for --ckpt_path is ''; pass None in that case so the
    # module's current weights are tested instead of loading a path named ''.
    trainer.test(model=lightning_module, dataloaders=test_loader,
                 ckpt_path=args.ckpt_path or None)
|
|
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # nargs='?' makes the declared default reachable; a plain positional
    # argument is mandatory and its default is never used.
    parser.add_argument('exp_dir', type=str, nargs='?',
                        default='experiments',
                        help="Path to save checkpoints and logs.")

    parser.add_argument('--sample_rate', type=int, default=32000)
    parser.add_argument('--ckpt_path', type=str, default='')
    parser.add_argument('--eval_csv', type=str, default='')
    parser.add_argument('--input_dir', type=str, default='')

    parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
                        help="Whether to use cuda")
    parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
                        help="List of GPU ids used for training. "
                             "Eg., --gpu_ids 2 4. All GPUs are used by default.")

    args = parser.parse_args()

    pl.seed_everything(114514)

    os.makedirs(args.exp_dir, exist_ok=True)

    # The model/optimizer settings live in <exp_dir>/config.json; fail early
    # with a clear message rather than inside the config loader.
    config_path = os.path.join(args.exp_dir, 'config.json')
    if not os.path.isfile(config_path):
        parser.error(f"config file not found: {config_path}")
    params = local_utils.Params(config_path)
    # Config entries are copied onto args after parsing, so they override
    # same-named CLI options.
    for k, v in params.__dict__.items():
        setattr(args, k, v)
    main(args)
|
|