Spaces:
Running
Running
| import json | |
| import os | |
| import random | |
| import sys | |
| from typing import List | |
| import librosa | |
| import numpy as np | |
| import soundfile | |
| from torch.utils.data import Dataset | |
| from tqdm import tqdm | |
| from utils.binary import DatasetReader | |
class CustomDataset(Dataset):
    """Map-style dataset that yields processor outputs for speech fine-tuning.

    Each item is built from one manifest entry: the audio is read (optionally
    only a time slice of the file), optionally down-mixed, augmented and
    resampled, and then passed to ``processor`` together with its transcript
    to obtain input features and label token ids.
    """

    def __init__(self,
                 data_list_path,
                 processor,
                 mono=True,
                 language=None,
                 timestamps=False,
                 sample_rate=16000,
                 min_duration=0.5,
                 max_duration=30,
                 augment_config_path=None):
        """
        Args:
            data_list_path: Path of the manifest. A path ending in ``.header``
                is opened with ``DatasetReader``; any other path is read as a
                text file holding one JSON object per line.
            processor: Whisper processor; must expose ``tokenizer`` and be
                callable with ``audio=`` / ``text=`` keyword arguments.
            mono: When True the audio is down-mixed with ``librosa.to_mono``.
            language: Fallback language used for the tokenizer prefix when a
                manifest entry has no ``language`` field.
            timestamps: When True, labels are built from the per-segment
                ``sentences`` field and include timestamp tokens.
            sample_rate: Sampling rate the audio is resampled to (default 16000).
            min_duration: Entries shorter than this are dropped; must be >= 0.5.
            max_duration: Entries longer than this are dropped; must be <= 30.
                ``-1`` disables the upper bound for text manifests.
            augment_config_path: Optional path of a JSON file holding a list
                of augmentation configs (see ``augment``).
        """
        super().__init__()
        assert min_duration >= 0.5, f"min_duration 0.5:{min_duration}"
        assert max_duration <= 30, f"max_duration 30:{max_duration}"
        self.data_list_path = data_list_path
        self.processor = processor
        self.sample_rate = sample_rate
        self.mono = mono
        self.language = language
        self.timestamps = timestamps
        self.min_duration = min_duration
        self.max_duration = max_duration
        # Ids of the special tokens needed to build labels by hand.
        self.vocab = self.processor.tokenizer.get_vocab()
        # Timestamp token ids are taken to start right after <|notimestamps|>.
        self.timestamp_begin = self.vocab['<|notimestamps|>'] + 1
        self.startoftranscript = self.vocab['<|startoftranscript|>']
        self.endoftext = self.vocab['<|endoftext|>']
        self.nocaptions = self.vocab['<|nocaptions|>']
        self.data_list: List[dict] = []
        # Load the manifest entries.
        self._load_data_list()
        # Augmentation state; speed rates and noise file list are filled lazily.
        self.augment_configs = None
        self.noises_path = None
        self.speed_rates = None
        if augment_config_path:
            with open(augment_config_path, 'r', encoding='utf-8') as f:
                self.augment_configs = json.load(f)

    def _load_data_list(self):
        """Populate ``self.data_list`` from ``self.data_list_path``.

        For ``.header`` manifests the keys returned by ``DatasetReader`` are
        stored; otherwise every JSON line whose ``duration`` lies within the
        configured bounds is stored as a dict.
        """
        if self.data_list_path.endswith(".header"):
            # Binary manifest: the reader applies the duration filter itself.
            self.dataset_reader = DatasetReader(data_header_path=self.data_list_path,
                                                min_duration=self.min_duration,
                                                max_duration=self.max_duration)
            self.data_list = self.dataset_reader.get_keys()
            return
        # Text manifest: one JSON object per line.
        with open(self.data_list_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        self.data_list = []
        for record in tqdm(lines, desc=''):
            if isinstance(record, str):
                # Blank lines would make json.loads raise; ignore them.
                if not record.strip():
                    continue
                record = json.loads(record)
            if not isinstance(record, dict):
                continue
            # Drop entries outside the allowed duration range.
            duration = record["duration"]
            if duration < self.min_duration:
                continue
            if self.max_duration != -1 and duration > self.max_duration:
                continue
            self.data_list.append(dict(record))

    def _get_list_data(self, idx):
        """Load and preprocess the audio of entry ``idx``.

        Returns:
            Tuple ``(sample, sample_rate, transcript, language)`` where
            ``sample_rate`` is the rate of the returned ``sample`` and
            ``language`` is None when the entry has no ``language`` field.
        """
        if self.data_list_path.endswith(".header"):
            data_list = self.dataset_reader.get_data(self.data_list[idx])
        else:
            data_list = self.data_list[idx]
        audio_info = data_list["audio"]
        audio_file = audio_info['path']
        transcript = data_list["sentences"] if self.timestamps else data_list["sentence"]
        language = data_list["language"] if 'language' in data_list.keys() else None
        if 'start_time' not in audio_info.keys():
            # No slice requested: read the whole file.
            sample, sample_rate = soundfile.read(audio_file, dtype='float32')
        else:
            # Read only the requested time span of the file.
            sample, sample_rate = self.slice_from_file(audio_file,
                                                       start=audio_info["start_time"],
                                                       end=audio_info["end_time"])
        # Transpose before down-mixing (a no-op for 1-D audio).
        sample = sample.T
        if self.mono:
            sample = librosa.to_mono(sample)
        if self.augment_configs:
            sample, sample_rate = self.augment(sample, sample_rate)
        if self.sample_rate != sample_rate:
            sample = self.resample(sample, orig_sr=sample_rate, target_sr=self.sample_rate)
            # Keep the reported rate in sync with the resampled audio.
            sample_rate = self.sample_rate
        return sample, sample_rate, transcript, language

    def _load_timestamps_transcript(self, transcript: List[dict]):
        """Build label ids with timestamp tokens from timed segments.

        Args:
            transcript: List of dicts with ``start``, ``end`` (seconds) and
                ``text`` keys.

        Returns:
            Dict with a single ``labels`` entry: the first three tokenizer
            prefix tokens, then ``<start> text-ids <end>`` for each segment,
            then the end-of-text id.
        """
        assert isinstance(transcript, list), f"transcript list:{type(transcript)}"
        data = dict()
        # Slicing copies, so the tokenizer's own prefix list is not mutated.
        labels = self.processor.tokenizer.prefix_tokens[:3]
        for t in transcript:
            # Timestamp tokens advance in 0.02 s steps: nudge odd centisecond
            # values onto an even one (start forward, end backward).
            start = t['start'] if round(t['start'] * 100) % 2 == 0 else t['start'] + 0.01
            start = self.timestamp_begin + round(start * 100) // 2
            end = t['end'] if round(t['end'] * 100) % 2 == 0 else t['end'] - 0.01
            end = self.timestamp_begin + round(end * 100) // 2
            # Drop the first four ids and the last id the processor adds.
            label = self.processor(text=t['text']).input_ids[4:-1]
            labels.extend([start])
            labels.extend(label)
            labels.extend([end])
        data['labels'] = labels + [self.endoftext]
        return data

    def __getitem__(self, idx):
        """Return the processor output for entry ``idx``.

        If anything fails while building the item, the error is printed to
        stderr and a randomly chosen entry is returned instead.
        NOTE(review): the retry recurses without a bound, so a dataset in
        which every entry fails ends in ``RecursionError``.
        """
        try:
            sample, sample_rate, transcript, language = self._get_list_data(idx=idx)
            # The entry's own language wins over the dataset-wide default.
            self.processor.tokenizer.set_prefix_tokens(language=language if language is not None else self.language)
            if len(transcript) > 0:
                if self.timestamps:
                    # Labels with timestamp tokens are assembled by hand.
                    data = self._load_timestamps_transcript(transcript=transcript)
                    data["input_features"] = self.processor(audio=sample, sampling_rate=self.sample_rate).input_features
                else:
                    # Let the processor produce both features and labels.
                    data = self.processor(audio=sample, sampling_rate=self.sample_rate, text=transcript)
            else:
                # Empty transcript: label the audio as containing no speech.
                data = self.processor(audio=sample, sampling_rate=self.sample_rate)
                data['labels'] = [self.startoftranscript, self.nocaptions, self.endoftext]
            return data
        except Exception as e:
            print(f'idx:{idx} error - {e}', file=sys.stderr)
            return self.__getitem__(random.randint(0, self.__len__() - 1))

    def __len__(self):
        """Return the number of manifest entries."""
        return len(self.data_list)

    @staticmethod
    def slice_from_file(file, start, end):
        """Read the span ``[start, end]`` (seconds) of an audio file.

        Negative ``start`` / ``end`` values are offsets from the end of the
        file. ``start`` is clamped to 0 and ``end`` to the file duration.

        Returns:
            Tuple ``(sample, sample_rate)`` with ``sample`` as float32.

        Raises:
            ValueError: If ``end`` is still negative after adjustment or
                ``start`` lies after ``end``.
        """
        # The context manager closes the file handle on every exit path.
        with soundfile.SoundFile(file) as sndfile:
            sample_rate = sndfile.samplerate
            duration = round(float(len(sndfile)) / sample_rate, 3)
            start = round(start, 3)
            end = round(end, 3)
            # Negative values count back from the end of the file.
            if start < 0.0:
                start += duration
            if end < 0.0:
                end += duration
            # Clamp to the valid range.
            if start < 0.0:
                start = 0.0
            if end > duration:
                end = duration
            if end < 0.0:
                raise ValueError("(%f s)" % end)
            if start > end:
                raise ValueError("(%f s)(%f s)" % (start, end))
            start_frame = int(start * sample_rate)
            end_frame = int(end * sample_rate)
            sndfile.seek(start_frame)
            sample = sndfile.read(frames=end_frame - start_frame, dtype='float32')
        return sample, sample_rate

    def augment(self, sample, sample_rate):
        """Apply the configured augmentations in order.

        Each config is a dict with ``type`` (``speed``, ``shift``, ``volume``,
        ``resample`` or ``noise``), ``prob`` and ``params``. A config is
        applied when ``random.random() < prob``; unknown types are ignored.

        Returns:
            Tuple ``(sample, sample_rate)``; the rate changes only when a
            ``resample`` augmentation was applied.
        """
        known_types = ('speed', 'shift', 'volume', 'resample', 'noise')
        for config in self.augment_configs:
            aug_type = config['type']
            if aug_type not in known_types:
                continue
            if not random.random() < config['prob']:
                continue
            params = config['params']
            if aug_type == 'speed':
                # Candidate rates are computed once and then reused.
                if self.speed_rates is None:
                    self.speed_rates = np.linspace(params['min_speed_rate'],
                                                   params['max_speed_rate'],
                                                   params['num_rates'],
                                                   endpoint=True)
                rate = random.choice(self.speed_rates)
                sample = self.change_speed(sample, speed_rate=rate)
            elif aug_type == 'shift':
                shift_ms = random.randint(params['min_shift_ms'], params['max_shift_ms'])
                sample = self.shift(sample, sample_rate, shift_ms=shift_ms)
            elif aug_type == 'volume':
                gain = random.randint(params['min_gain_dBFS'], params['max_gain_dBFS'])
                sample = self.volume(sample, gain=gain)
            elif aug_type == 'resample':
                new_sample_rate = np.random.choice(params['new_sample_rates'])
                sample = self.resample(sample, orig_sr=sample_rate, target_sr=new_sample_rate)
                sample_rate = new_sample_rate
            else:  # 'noise'
                min_snr_dB, max_snr_dB = params['min_snr_dB'], params['max_snr_dB']
                # The noise directory is listed once and then reused.
                if self.noises_path is None:
                    self.noises_path = []
                    noise_dir = params['noise_dir']
                    if os.path.exists(noise_dir):
                        for file in os.listdir(noise_dir):
                            self.noises_path.append(os.path.join(noise_dir, file))
                # Nothing to mix in when the directory is missing or empty.
                if not self.noises_path:
                    continue
                noise_path = random.choice(self.noises_path)
                snr_dB = random.randint(min_snr_dB, max_snr_dB)
                sample = self.add_noise(sample, sample_rate, noise_path=noise_path, snr_dB=snr_dB)
        return sample, sample_rate

    @staticmethod
    def change_speed(sample, speed_rate):
        """Change playback speed by linear interpolation.

        The output has ``int(len(sample) / speed_rate)`` float32 samples. A
        rate of exactly 1.0 returns the input unchanged.

        Raises:
            ValueError: If ``speed_rate`` is not positive.
        """
        if speed_rate == 1.0:
            return sample
        if speed_rate <= 0:
            raise ValueError("error")
        old_length = sample.shape[0]
        new_length = int(old_length / speed_rate)
        old_indices = np.arange(old_length)
        new_indices = np.linspace(start=0, stop=old_length, num=new_length)
        sample = np.interp(new_indices, old_indices, sample).astype(np.float32)
        return sample

    @staticmethod
    def shift(sample, sample_rate, shift_ms):
        """Shift the audio in time, in place, filling the gap with zeros.

        A positive ``shift_ms`` moves the content towards the start and
        zeroes the tail; a negative value moves it towards the end and
        zeroes the head.

        Raises:
            ValueError: If the shift is longer than the audio.
        """
        duration = sample.shape[0] / sample_rate
        if abs(shift_ms) / 1000.0 > duration:
            raise ValueError("shift_ms")
        shift_samples = int(shift_ms * sample_rate / 1000)
        if shift_samples > 0:
            sample[:-shift_samples] = sample[shift_samples:]
            sample[-shift_samples:] = 0
        elif shift_samples < 0:
            sample[-shift_samples:] = sample[:shift_samples]
            sample[:-shift_samples] = 0
        return sample

    @staticmethod
    def volume(sample, gain):
        """Scale ``sample`` in place by ``gain`` decibels and return it."""
        sample *= 10.**(gain / 20.)
        return sample

    @staticmethod
    def resample(sample, orig_sr, target_sr):
        """Resample ``sample`` from ``orig_sr`` to ``target_sr`` with librosa."""
        sample = librosa.resample(sample, orig_sr=orig_sr, target_sr=target_sr)
        return sample

    def add_noise(self, sample, sample_rate, noise_path, snr_dB, max_gain_db=300.0):
        """Mix a noise recording into ``sample`` at the given SNR.

        ``sample`` is first normalised in place to -20 dB RMS, the noise is
        scaled to sit ``snr_dB`` below it (both gains capped at
        ``max_gain_db``), and the noise is wrapped or randomly cropped to the
        length of ``sample`` before being added.

        Returns:
            The modified ``sample`` array.
        """
        noise_sample, sr = librosa.load(noise_path, sr=sample_rate)
        # Normalise the speech level so the SNR is defined against -20 dB.
        target_db = -20
        gain = min(max_gain_db, target_db - self.rms_db(sample))
        sample *= 10. ** (gain / 20.)
        # Scale the noise to the requested signal-to-noise ratio.
        sample_rms_db, noise_rms_db = self.rms_db(sample), self.rms_db(noise_sample)
        noise_gain_db = min(sample_rms_db - noise_rms_db - snr_dB, max_gain_db)
        noise_sample *= 10. ** (noise_gain_db / 20.)
        # Match the noise length to the speech length.
        if noise_sample.shape[0] < sample.shape[0]:
            diff_duration = sample.shape[0] - noise_sample.shape[0]
            noise_sample = np.pad(noise_sample, (0, diff_duration), 'wrap')
        elif noise_sample.shape[0] > sample.shape[0]:
            start_frame = random.randint(0, noise_sample.shape[0] - sample.shape[0])
            noise_sample = noise_sample[start_frame:sample.shape[0] + start_frame]
        sample += noise_sample
        return sample

    @staticmethod
    def rms_db(sample):
        """Return the RMS level of ``sample`` in decibels.

        An all-zero input yields ``-inf``.
        """
        mean_square = np.mean(sample ** 2)
        return 10 * np.log10(mean_square)