Spaces:
Sleeping
Sleeping
import json | |
import os | |
import random | |
import sys | |
from typing import List | |
import librosa | |
import numpy as np | |
import soundfile | |
from torch.utils.data import Dataset | |
from tqdm import tqdm | |
from utils.binary import DatasetReader | |
class CustomDataset(Dataset):
    """Speech dataset that turns audio records into Whisper-processor inputs.

    Each record provides an audio path (optionally with a start/end window),
    a transcript and optionally a language. ``__getitem__`` loads the audio,
    optionally augments it, resamples it to ``sample_rate`` and runs the
    processor to obtain model inputs and label ids.
    """

    def __init__(self,
                 data_list_path,
                 processor,
                 mono=True,
                 language=None,
                 timestamps=False,
                 sample_rate=16000,
                 min_duration=0.5,
                 max_duration=30,
                 augment_config_path=None):
        """
        Args:
            data_list_path: Path of the data list. A path ending in ".header"
                is read through ``DatasetReader``; any other path is read as a
                JSON-lines text file (one JSON object per line).
            processor: Whisper processor. It must expose ``tokenizer`` and be
                callable with ``audio=``/``sampling_rate=``/``text=``.
            mono: When True the audio is down-mixed with ``librosa.to_mono``.
            language: Fallback language passed to
                ``tokenizer.set_prefix_tokens`` when a record has none.
            timestamps: When True the record's "sentences" list is used and
                timestamp label ids are generated; otherwise "sentence".
            sample_rate: Target sample rate of the returned audio (16000).
            min_duration: Records shorter than this (seconds) are dropped;
                must be >= 0.5.
            max_duration: Records longer than this (seconds) are dropped;
                must be <= 30. -1 disables the upper bound for JSON-lines
                lists.
            augment_config_path: Optional path of a JSON file holding a list
                of augmentation configs (``type``/``prob``/``params``).
        """
        super(CustomDataset, self).__init__()
        assert min_duration >= 0.5, f"min_duration 0.5:{min_duration}"
        assert max_duration <= 30, f"max_duration 30:{max_duration}"
        self.data_list_path = data_list_path
        self.processor = processor
        self.sample_rate = sample_rate
        self.mono = mono
        self.language = language
        self.timestamps = timestamps
        self.min_duration = min_duration
        self.max_duration = max_duration
        # Special-token ids looked up once from the tokenizer vocabulary.
        self.vocab = self.processor.tokenizer.get_vocab()
        # Timestamp ids start right after the <|notimestamps|> token.
        self.timestamp_begin = self.vocab['<|notimestamps|>'] + 1
        self.startoftranscript = self.vocab['<|startoftranscript|>']
        self.endoftext = self.vocab['<|endoftext|>']
        self.nocaptions = self.vocab['<|nocaptions|>']
        self.data_list: List[dict] = []
        # Load (and duration-filter) the record list.
        self._load_data_list()
        # Augmentation state; noise file list and speed rates are built lazily
        # the first time the corresponding augmentation fires.
        self.augment_configs = None
        self.noises_path = None
        self.speed_rates = None
        if augment_config_path:
            with open(augment_config_path, 'r', encoding='utf-8') as f:
                self.augment_configs = json.load(f)

    def _load_data_list(self):
        """Populate ``self.data_list`` from ``self.data_list_path``.

        For ".header" paths the list holds the keys returned by
        ``DatasetReader.get_keys()``; otherwise it holds one dict per JSON
        line whose "duration" lies within the configured bounds.
        """
        if self.data_list_path.endswith(".header"):
            # Binary dataset: the reader applies the duration limits itself.
            self.dataset_reader = DatasetReader(data_header_path=self.data_list_path,
                                                min_duration=self.min_duration,
                                                max_duration=self.max_duration)
            self.data_list = self.dataset_reader.get_keys()
            return

        # Plain-text dataset: one JSON object per line.
        with open(self.data_list_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        self.data_list = []
        for line in tqdm(lines, desc=''):
            if isinstance(line, str):
                # Blank lines (e.g. a trailing newline) are not records.
                if not line.strip():
                    continue
                line = json.loads(line)
            if not isinstance(line, dict):
                continue
            # Drop records outside the allowed duration range.
            if line["duration"] < self.min_duration:
                continue
            if self.max_duration != -1 and line["duration"] > self.max_duration:
                continue
            self.data_list.append(dict(line))

    def _get_list_data(self, idx):
        """Load, augment and resample the audio of record ``idx``.

        Returns:
            Tuple ``(sample, sample_rate, transcript, language)`` where
            ``sample_rate`` is the rate of the returned ``sample`` and
            ``language`` is None when the record does not specify one.
        """
        if self.data_list_path.endswith(".header"):
            data_list = self.dataset_reader.get_data(self.data_list[idx])
        else:
            data_list = self.data_list[idx]
        audio_info = data_list["audio"]
        audio_file = audio_info['path']
        transcript = data_list["sentences"] if self.timestamps else data_list["sentence"]
        language = data_list["language"] if 'language' in data_list else None
        if 'start_time' not in audio_info:
            sample, sample_rate = soundfile.read(audio_file, dtype='float32')
        else:
            # Only the requested window of the file is read.
            sample, sample_rate = self.slice_from_file(audio_file,
                                                       start=audio_info["start_time"],
                                                       end=audio_info["end_time"])
        # Transposed before the optional down-mix; a no-op for 1-D audio.
        sample = sample.T
        if self.mono:
            sample = librosa.to_mono(sample)
        if self.augment_configs:
            sample, sample_rate = self.augment(sample, sample_rate)
        if self.sample_rate != sample_rate:
            sample = self.resample(sample, orig_sr=sample_rate, target_sr=self.sample_rate)
            # Keep the returned rate consistent with the returned samples.
            sample_rate = self.sample_rate
        return sample, sample_rate, transcript, language

    def _load_timestamps_transcript(self, transcript: List[dict]):
        """Build label ids with timestamp tokens around every segment.

        Args:
            transcript: List of dicts with "start", "end" (seconds) and "text".

        Returns:
            Dict with a single "labels" entry: the first three tokenizer
            prefix tokens, then ``start-id, text-ids, end-id`` per segment,
            then the end-of-text id.
        """
        assert isinstance(transcript, list), f"transcript list:{type(transcript)}"
        data = dict()
        # Slicing copies, so the tokenizer's own prefix list is not mutated.
        labels = self.processor.tokenizer.prefix_tokens[:3]
        for t in transcript:
            # Timestamp ids advance once per 0.02 s, so odd centisecond values
            # are nudged to an even one (start rounds up, end rounds down).
            start = t['start'] if round(t['start'] * 100) % 2 == 0 else t['start'] + 0.01
            start = self.timestamp_begin + round(start * 100) // 2
            end = t['end'] if round(t['end'] * 100) % 2 == 0 else t['end'] - 0.01
            end = self.timestamp_begin + round(end * 100) // 2
            # Strip the leading prefix ids and the trailing id from the text.
            label = self.processor(text=t['text']).input_ids[4:-1]
            labels.append(start)
            labels.extend(label)
            labels.append(end)
        data['labels'] = labels + [self.endoftext]
        return data

    def __getitem__(self, idx):
        """Return processor output for record ``idx``.

        If anything fails while loading or processing the record, the error
        is reported on stderr and a randomly chosen record is returned
        instead (best-effort behaviour).
        """
        try:
            sample, _, transcript, language = self._get_list_data(idx=idx)
            # A per-record language overrides the dataset-level default.
            self.processor.tokenizer.set_prefix_tokens(language=language if language is not None else self.language)
            if len(transcript) > 0:
                if self.timestamps:
                    # Labels carry timestamp ids; features are computed separately.
                    data = self._load_timestamps_transcript(transcript=transcript)
                    data["input_features"] = self.processor(audio=sample, sampling_rate=self.sample_rate).input_features
                else:
                    # The processor produces features and label ids in one call.
                    data = self.processor(audio=sample, sampling_rate=self.sample_rate, text=transcript)
            else:
                # Empty transcript: label the audio as containing no captions.
                data = self.processor(audio=sample, sampling_rate=self.sample_rate)
                data['labels'] = [self.startoftranscript, self.nocaptions, self.endoftext]
            return data
        except Exception as e:
            print(f'idx:{idx} error - {e}', file=sys.stderr)
            return self.__getitem__(random.randint(0, self.__len__() - 1))

    def __len__(self):
        """Return the number of records kept after duration filtering."""
        return len(self.data_list)

    @staticmethod
    def slice_from_file(file, start, end):
        """Read the ``[start, end]`` window (seconds) of an audio file.

        Negative ``start``/``end`` values are offsets from the end of the
        file. ``start`` is clamped to 0 and ``end`` to the file duration.

        Returns:
            Tuple ``(sample, sample_rate)`` with ``sample`` as float32.

        Raises:
            ValueError: If ``end`` is still negative after adjustment or
                ``start`` is greater than ``end``.
        """
        # The context manager closes the file handle even when validation fails.
        with soundfile.SoundFile(file) as sndfile:
            sample_rate = sndfile.samplerate
            duration = round(float(len(sndfile)) / sample_rate, 3)
            start = round(start, 3)
            end = round(end, 3)
            # Negative positions count back from the end of the file.
            if start < 0.0:
                start += duration
            if end < 0.0:
                end += duration
            # Clamp to the valid range.
            if start < 0.0:
                start = 0.0
            if end > duration:
                end = duration
            if end < 0.0:
                raise ValueError("(%f s)" % end)
            if start > end:
                raise ValueError("(%f s)(%f s)" % (start, end))
            start_frame = int(start * sample_rate)
            end_frame = int(end * sample_rate)
            sndfile.seek(start_frame)
            sample = sndfile.read(frames=end_frame - start_frame, dtype='float32')
        return sample, sample_rate

    def augment(self, sample, sample_rate):
        """Apply every configured augmentation with its own probability.

        Supported ``type`` values: "speed", "shift", "volume", "resample"
        and "noise". Each config is applied when ``random.random()`` is below
        its ``prob``.

        Returns:
            Tuple ``(sample, sample_rate)``; the rate changes only when a
            "resample" augmentation fired.
        """
        for config in self.augment_configs:
            if config['type'] == 'speed' and random.random() < config['prob']:
                if self.speed_rates is None:
                    params = config['params']
                    # Candidate rates are computed once and reused afterwards.
                    self.speed_rates = np.linspace(params['min_speed_rate'],
                                                   params['max_speed_rate'],
                                                   params['num_rates'],
                                                   endpoint=True)
                rate = random.choice(self.speed_rates)
                sample = self.change_speed(sample, speed_rate=rate)
            if config['type'] == 'shift' and random.random() < config['prob']:
                min_shift_ms, max_shift_ms = config['params']['min_shift_ms'], config['params']['max_shift_ms']
                shift_ms = random.randint(min_shift_ms, max_shift_ms)
                sample = self.shift(sample, sample_rate, shift_ms=shift_ms)
            if config['type'] == 'volume' and random.random() < config['prob']:
                min_gain_dBFS, max_gain_dBFS = config['params']['min_gain_dBFS'], config['params']['max_gain_dBFS']
                gain = random.randint(min_gain_dBFS, max_gain_dBFS)
                sample = self.volume(sample, gain=gain)
            if config['type'] == 'resample' and random.random() < config['prob']:
                new_sample_rates = config['params']['new_sample_rates']
                new_sample_rate = np.random.choice(new_sample_rates)
                sample = self.resample(sample, orig_sr=sample_rate, target_sr=new_sample_rate)
                sample_rate = new_sample_rate
            if config['type'] == 'noise' and random.random() < config['prob']:
                min_snr_dB, max_snr_dB = config['params']['min_snr_dB'], config['params']['max_snr_dB']
                if self.noises_path is None:
                    # The noise directory is listed once and cached.
                    self.noises_path = []
                    noise_dir = config['params']['noise_dir']
                    if os.path.exists(noise_dir):
                        for file in os.listdir(noise_dir):
                            self.noises_path.append(os.path.join(noise_dir, file))
                # A missing or empty noise directory means there is nothing to
                # mix in; skip instead of failing in random.choice.
                if not self.noises_path:
                    continue
                noise_path = random.choice(self.noises_path)
                snr_dB = random.randint(min_snr_dB, max_snr_dB)
                sample = self.add_noise(sample, sample_rate, noise_path=noise_path, snr_dB=snr_dB)
        return sample, sample_rate

    @staticmethod
    def change_speed(sample, speed_rate):
        """Change playback speed by linear-interpolation resampling.

        A rate of 1.0 returns ``sample`` unchanged; otherwise the result has
        ``int(len(sample) / speed_rate)`` float32 samples.

        Raises:
            ValueError: If ``speed_rate`` is not positive.
        """
        if speed_rate == 1.0:
            return sample
        if speed_rate <= 0:
            raise ValueError("error")
        old_length = sample.shape[0]
        new_length = int(old_length / speed_rate)
        old_indices = np.arange(old_length)
        new_indices = np.linspace(start=0, stop=old_length, num=new_length)
        sample = np.interp(new_indices, old_indices, sample).astype(np.float32)
        return sample

    @staticmethod
    def shift(sample, sample_rate, shift_ms):
        """Shift the audio in time in place, filling the gap with zeros.

        A positive ``shift_ms`` moves the content towards the start (the
        tail is zeroed); a negative value moves it towards the end (the head
        is zeroed).

        Raises:
            ValueError: If the shift is longer than the audio itself.
        """
        duration = sample.shape[0] / sample_rate
        if abs(shift_ms) / 1000.0 > duration:
            raise ValueError("shift_ms")
        shift_samples = int(shift_ms * sample_rate / 1000)
        if shift_samples > 0:
            sample[:-shift_samples] = sample[shift_samples:]
            sample[-shift_samples:] = 0
        elif shift_samples < 0:
            sample[-shift_samples:] = sample[:shift_samples]
            sample[:-shift_samples] = 0
        return sample

    @staticmethod
    def volume(sample, gain):
        """Scale the audio in place by ``gain`` decibels and return it."""
        sample *= 10.**(gain / 20.)
        return sample

    @staticmethod
    def resample(sample, orig_sr, target_sr):
        """Resample ``sample`` from ``orig_sr`` to ``target_sr`` with librosa."""
        sample = librosa.resample(sample, orig_sr=orig_sr, target_sr=target_sr)
        return sample

    def add_noise(self, sample, sample_rate, noise_path, snr_dB, max_gain_db=300.0):
        """Mix a noise recording into ``sample`` at the requested SNR.

        The audio is first normalised to -20 dB RMS, the noise is scaled so
        the signal-to-noise ratio equals ``snr_dB`` (gains capped at
        ``max_gain_db``), and the noise is wrapped or randomly cropped to the
        length of the audio before being added in place.
        """
        noise_sample, sr = librosa.load(noise_path, sr=sample_rate)
        # Normalise the speech level.
        target_db = -20
        gain = min(max_gain_db, target_db - self.rms_db(sample))
        sample *= 10. ** (gain / 20.)
        # Scale the noise relative to the normalised speech level.
        sample_rms_db, noise_rms_db = self.rms_db(sample), self.rms_db(noise_sample)
        noise_gain_db = min(sample_rms_db - noise_rms_db - snr_dB, max_gain_db)
        noise_sample *= 10. ** (noise_gain_db / 20.)
        # Match the noise length to the audio length.
        if noise_sample.shape[0] < sample.shape[0]:
            diff_duration = sample.shape[0] - noise_sample.shape[0]
            noise_sample = np.pad(noise_sample, (0, diff_duration), 'wrap')
        elif noise_sample.shape[0] > sample.shape[0]:
            start_frame = random.randint(0, noise_sample.shape[0] - sample.shape[0])
            noise_sample = noise_sample[start_frame:sample.shape[0] + start_frame]
        sample += noise_sample
        return sample

    @staticmethod
    def rms_db(sample):
        """Return the RMS level of ``sample`` in decibels as a float.

        The mean square is floored at the smallest positive float so that
        silent audio yields a large negative finite value instead of -inf
        (which would turn into NaN gains in ``add_noise``).
        """
        mean_square = float(np.mean(sample ** 2))
        floor = float(np.finfo(np.float64).tiny)
        return float(10 * np.log10(max(mean_square, floor)))