Spaces:
Build error
Build error
File size: 5,112 Bytes
15ac91d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import os
import logging
import h5py
import soundfile
import librosa
import numpy as np
import pandas as pd
from scipy import stats
import datetime
import pickle
def create_folder(fd):
    """Create the directory `fd`, including missing parents, if it is absent.

    Args:
      fd: str, path of the directory to create

    Raises:
      FileExistsError: if `fd` already exists but is not a directory.
    """
    # exist_ok=True removes the check-then-create race that occurs when
    # several processes try to create the same folder at the same time.
    os.makedirs(fd, exist_ok=True)
def get_filename(path):
    """Return the file name of `path` without directory and last extension.

    The path is resolved with os.path.realpath first, so for a symbolic link
    the name of the link target is returned.

    Args:
      path: str

    Returns:
      na: str, e.g. 'clip' for '/data/audio/clip.wav'
    """
    resolved = os.path.realpath(path)
    # os.path.basename honours the platform's separator; splitting on a
    # hard-coded '/' returned the whole path on Windows.
    na_ext = os.path.basename(resolved)
    return os.path.splitext(na_ext)[0]
def get_sub_filepaths(folder):
    """Collect the paths of all files found under `folder`, recursively.

    Args:
      folder: str, root directory handed to os.walk

    Returns:
      paths: list of str, one entry per file, in os.walk order
    """
    return [
        os.path.join(parent, filename)
        for parent, _subdirs, filenames in os.walk(folder)
        for filename in filenames
    ]
def create_logging(log_dir, filemode):
    """Set up root logging to a new numbered log file and to the console.

    The log file is the first 'NNNN.log' (0000.log, 0001.log, ...) that does
    not exist yet in `log_dir`.

    Args:
      log_dir: str, folder of log files, created when missing
      filemode: str, file mode forwarded to logging.basicConfig

    Returns:
      the `logging` module
    """
    create_folder(log_dir)

    # Find the first log file name that is still free.
    index = 0
    log_path = os.path.join(log_dir, '{:04d}.log'.format(index))
    while os.path.isfile(log_path):
        index += 1
        log_path = os.path.join(log_dir, '{:04d}.log'.format(index))

    logging.basicConfig(
        filename=log_path,
        filemode=filemode,
        level=logging.DEBUG,
        format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S')

    # Print to console: INFO and above, with a shorter format than the file.
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(
        logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s'))
    console_handler.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console_handler)

    return logging
def read_metadata(csv_path, classes_num, id_to_ix):
    """Read metadata of AudioSet from a csv file.

    Args:
      csv_path: str, path of the csv file; its first three lines are treated
        as headers and skipped
      classes_num: int, number of columns of the target matrix
      id_to_ix: dict, maps a label id (str) to its column index (int)

    Returns:
      meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)}

    Raises:
      KeyError: if a label id of the csv file is missing from `id_to_ix`.
    """
    with open(csv_path, 'r') as fr:
        lines = fr.readlines()

    # Remove heads; also drop blank lines (e.g. a trailing empty line), which
    # have no label field to parse.
    lines = [line for line in lines[3:] if line.strip()]

    audios_num = len(lines)
    # Use the builtin bool: the np.bool alias was removed in NumPy 1.24.
    targets = np.zeros((audios_num, classes_num), dtype=bool)
    audio_names = []

    for n, line in enumerate(lines):
        items = line.split(', ')
        # items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']

        # Audios are started with an extra 'Y' when downloading
        audio_names.append('Y{}.wav'.format(items[0]))

        # The label ids sit between the double quotes of the last field.
        label_ids = items[3].split('"')[1].split(',')

        # Target
        for label_id in label_ids:
            targets[n, id_to_ix[label_id]] = True

    meta_dict = {'audio_name': np.array(audio_names), 'target': targets}
    return meta_dict
def float32_to_int16(x):
    """Quantise a float waveform to int16.

    Samples are clipped to [-1, 1] and scaled by 32767. The assertion rejects
    input whose largest absolute value exceeds 1.2.

    Args:
      x: ndarray of float samples

    Returns:
      ndarray of np.int16 with the same shape as `x`
    """
    peak = np.max(np.abs(x))
    assert peak <= 1.2
    clipped = np.clip(x, -1, 1)
    scaled = clipped * 32767.
    return scaled.astype(np.int16)
def int16_to_float32(x):
    """Convert int16 samples to float32 by dividing by 32767.

    Args:
      x: ndarray of integer samples

    Returns:
      ndarray of np.float32 with the same shape as `x`
    """
    scaled = x / 32767.
    return scaled.astype(np.float32)
def pad_or_truncate(x, audio_length):
    """Pad all audio to specific length.

    Shorter input is zero-padded at the end, longer input is cut.

    Args:
      x: 1-D ndarray of samples
      audio_length: int, number of samples of the result

    Returns:
      ndarray of length `audio_length`. The padding branch concatenates with
      float64 zeros, so its result dtype follows NumPy's promotion rules; the
      truncating branch returns a slice of `x`.
    """
    missing = audio_length - len(x)
    if missing < 0:
        return x[:audio_length]

    padding = np.zeros(missing)
    return np.concatenate((x, padding), axis=0)
def d_prime(auc):
    """Compute sqrt(2) times the standard normal quantile of `auc`.

    Args:
      auc: float or array-like, passed to the normal ppf

    Returns:
      float or ndarray, `stats.norm().ppf(auc) * sqrt(2)`
    """
    standard_normal = stats.norm()
    quantile = standard_normal.ppf(auc)
    return quantile * np.sqrt(2.0)
class Mixup(object):
    def __init__(self, mixup_alpha, random_seed=1234):
        """Mixup coefficient generator.

        Args:
          mixup_alpha: float, used as both shape parameters of the Beta
            distribution the coefficients are drawn from
          random_seed: int, seed of the generator's own RandomState
        """
        self.mixup_alpha = mixup_alpha
        self.random_state = np.random.RandomState(random_seed)

    def get_lambda(self, batch_size):
        """Get mixup random coefficients.

        Coefficients come in consecutive pairs (lam, 1 - lam), one Beta draw
        per pair.

        Args:
          batch_size: int

        Returns:
          mixup_lambdas: (batch_size,)
        """
        mixup_lambdas = []
        # One draw of size 1 per pair keeps the random stream identical to
        # earlier runs with the same seed.
        for _ in range(0, batch_size, 2):
            lam = self.random_state.beta(self.mixup_alpha, self.mixup_alpha, 1)[0]
            mixup_lambdas.append(lam)
            mixup_lambdas.append(1. - lam)

        # For an odd batch_size the loop yields batch_size + 1 values; trim so
        # the result always has the documented (batch_size,) shape.
        return np.array(mixup_lambdas[:batch_size])
class StatisticsContainer(object):
    def __init__(self, statistics_path):
        """Contain statistics of different training iterations.

        Args:
          statistics_path: str, pickle file the statistics are dumped to and
            loaded from. A time-stamped backup path
            '<statistics_path without extension>_<YYYY-mm-dd_HH-MM-SS>.pkl'
            is derived from it at construction time.
        """
        self.statistics_path = statistics_path

        self.backup_statistics_path = '{}_{}.pkl'.format(
            os.path.splitext(self.statistics_path)[0],
            datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))

        self.statistics_dict = {'bal': [], 'test': []}

    def append(self, iteration, statistics, data_type):
        """Record the statistics of one iteration.

        Args:
          iteration: int, stored in `statistics` under the key 'iteration'
            (the passed dict is modified in place)
          statistics: dict
          data_type: str, key of the statistics list to append to,
            e.g. 'bal' or 'test'
        """
        statistics['iteration'] = iteration
        self.statistics_dict[data_type].append(statistics)

    def dump(self):
        """Write all statistics to the main and to the backup pickle file."""
        # Context managers close the files even if pickling fails.
        for path in (self.statistics_path, self.backup_statistics_path):
            with open(path, 'wb') as f:
                pickle.dump(self.statistics_dict, f)

        logging.info('    Dump statistics to {}'.format(self.statistics_path))
        logging.info('    Dump statistics to {}'.format(self.backup_statistics_path))

    def load_state_dict(self, resume_iteration):
        """Load statistics from disk, keeping those up to `resume_iteration`.

        Args:
          resume_iteration: int, entries whose 'iteration' is larger than
            this are dropped
        """
        # NOTE: pickle.load can execute arbitrary code; only load statistics
        # files written by this program or another trusted source.
        with open(self.statistics_path, 'rb') as f:
            self.statistics_dict = pickle.load(f)

        # Always provide the 'bal' and 'test' lists, and carry over any other
        # key found in the file.
        resume_statistics_dict = {'bal': [], 'test': []}

        for key, entries in self.statistics_dict.items():
            resume_statistics_dict[key] = [
                statistics for statistics in entries
                if statistics['iteration'] <= resume_iteration]

        self.statistics_dict = resume_statistics_dict