Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import csv | |
from pathlib import Path | |
import zipfile | |
from functools import reduce | |
from multiprocessing import cpu_count | |
from typing import Any, Dict, List, Optional, Union | |
import io | |
import numpy as np | |
import pandas as pd | |
import sentencepiece as sp | |
from fairseq.data.audio.audio_utils import ( | |
convert_waveform, _get_kaldi_fbank, _get_torchaudio_fbank, is_npy_data, | |
is_sf_audio_data | |
) | |
import torch | |
import soundfile as sf | |
from tqdm import tqdm | |
# Special-token strings and the fixed ids gen_vocab asks the SentencePiece
# trainer to assign them (passed as --bos_id/--pad_id/--eos_id/--unk_id).
BOS_TOKEN = "<s>"
BOS_TOKEN_ID = 0

PAD_TOKEN = "<pad>"
PAD_TOKEN_ID = 1

EOS_TOKEN = "</s>"
EOS_TOKEN_ID = 2

UNK_TOKEN = "<unk>"
UNK_TOKEN_ID = 3
def gen_vocab(
    input_path: Path, output_path_prefix: Path, model_type="bpe",
    vocab_size=1000, special_symbols: Optional[List[str]] = None
):
    """Train a SentencePiece model and export a matching fairseq dictionary.

    Args:
        input_path: text file the SentencePiece model is trained on.
        output_path_prefix: prefix for the outputs. The trainer writes the
            model under this prefix (``<prefix>.model`` is loaded back below),
            and the fairseq dictionary is written to ``<prefix>.txt``.
        model_type: SentencePiece ``--model_type`` value (e.g. "bpe").
        vocab_size: SentencePiece ``--vocab_size`` value.
        special_symbols: optional extra symbols passed to the trainer as
            ``--user_defined_symbols``.

    Raises:
        AssertionError: if the trained model does not place the special
            tokens at the ids requested via the module-level ``*_TOKEN_ID``
            constants.
    """
    # Train SentencePiece Model
    arguments = [
        f"--input={input_path.as_posix()}",
        f"--model_prefix={output_path_prefix.as_posix()}",
        f"--model_type={model_type}",
        f"--vocab_size={vocab_size}",
        "--character_coverage=1.0",
        f"--num_threads={cpu_count()}",
        f"--unk_id={UNK_TOKEN_ID}",
        f"--bos_id={BOS_TOKEN_ID}",
        f"--eos_id={EOS_TOKEN_ID}",
        f"--pad_id={PAD_TOKEN_ID}",
    ]
    if special_symbols is not None:
        _special_symbols = ",".join(special_symbols)
        arguments.append(f"--user_defined_symbols={_special_symbols}")
    # NOTE(review): the flags are joined into one space-separated string, so
    # paths containing spaces are presumably split by the trainer - confirm.
    sp.SentencePieceTrainer.Train(" ".join(arguments))

    # Export fairseq dictionary
    spm = sp.SentencePieceProcessor()
    spm.Load(output_path_prefix.as_posix() + ".model")
    vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
    assert (
        vocab.get(UNK_TOKEN_ID) == UNK_TOKEN
        and vocab.get(PAD_TOKEN_ID) == PAD_TOKEN
        and vocab.get(BOS_TOKEN_ID) == BOS_TOKEN
        and vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
    )
    # The special tokens are excluded from the exported dictionary file.
    vocab = {
        i: s
        for i, s in vocab.items()
        if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
    }
    # Pieces are arbitrary Unicode (SentencePiece marks word starts with a
    # non-ASCII character), so write UTF-8 explicitly rather than relying on
    # the platform's locale encoding.
    with open(
        output_path_prefix.as_posix() + ".txt", "w", encoding="utf-8"
    ) as f_out:
        # One "<piece> 1" line per piece, in ascending id order.
        for _, s in sorted(vocab.items(), key=lambda x: x[0]):
            f_out.write(f"{s} 1\n")
def extract_fbank_features(
    waveform: torch.FloatTensor,
    sample_rate: int,
    output_path: Optional[Path] = None,
    n_mel_bins: int = 80,
    overwrite: bool = False,
):
    """Compute mel filterbank features for a waveform, optionally caching to .npy.

    If ``output_path`` already exists as a file and ``overwrite`` is False,
    nothing is computed and ``None`` is returned. Otherwise the waveform is
    converted to mono, and features are produced by the Kaldi backend when it
    is available, falling back to torchaudio.

    Args:
        waveform: input audio; its layout is whatever ``convert_waveform``
            accepts (TODO confirm expected shape).
        sample_rate: sampling rate of ``waveform``.
        output_path: if given, the features are saved there with ``np.save``.
        n_mel_bins: number of mel bins to request from the backend.
        overwrite: recompute even if ``output_path`` already exists.

    Returns:
        The feature array from the selected backend, or ``None`` when an
        existing output file was reused.

    Raises:
        ImportError: if neither backend produced features.
    """
    reuse_existing = (
        output_path is not None and output_path.is_file() and not overwrite
    )
    if reuse_existing:
        return None

    mono, _ = convert_waveform(waveform, sample_rate, to_mono=True)
    # Kaldi compliance: scale to the 16-bit signed integer range.
    samples = (mono * (2 ** 15)).numpy()

    features = _get_kaldi_fbank(samples, sample_rate, n_mel_bins)
    if features is None:
        features = _get_torchaudio_fbank(samples, sample_rate, n_mel_bins)
    if features is None:
        raise ImportError(
            "Please install pyKaldi or torchaudio to enable fbank feature extraction"
        )

    if output_path is not None:
        np.save(output_path.as_posix(), features)
    return features
def create_zip(data_root: Path, zip_path: Path):
    """Pack the ``*.npy`` and ``*.flac`` files directly under *data_root* into a zip.

    Members are stored uncompressed (``ZIP_STORED``) under their bare file
    names. All ``.npy`` files are added first, then all ``.flac`` files, each
    in the order ``Path.glob`` yields them.
    """
    members = [*data_root.glob("*.npy"), *data_root.glob("*.flac")]
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as archive:
        for member in tqdm(members):
            archive.write(member, arcname=member.name)
def _zip_member_data_offset(f, info: zipfile.ZipInfo) -> int:
    """Return the absolute byte offset of *info*'s data in the open archive *f*.

    Reads the member's local file header: 30 fixed bytes followed by the file
    name and an extra field, whose *encoded byte* lengths are the two
    little-endian uint16 values at header offsets 26 and 28. These must be
    used rather than ``len(info.filename)``, which counts characters (wrong
    for non-ASCII names) and ignores the extra field.
    """
    import struct

    f.seek(info.header_offset)
    header = f.read(30)
    assert header[:4] == b"PK\x03\x04", info
    name_len, extra_len = struct.unpack("<HH", header[26:30])
    return info.header_offset + 30 + name_len + extra_len


def get_zip_manifest(
    zip_path: Path, zip_root: Optional[Path] = None, is_audio=False
):
    """Index an uncompressed zip of features/audio for direct byte-range reads.

    Args:
        zip_path: archive path; recorded verbatim (as posix) in the manifest.
        zip_root: optional directory that *zip_path* is relative to when the
            archive is opened.
        is_audio: if True, members are validated as soundfile-readable audio
            and their length is the number of frames reported by
            ``sf.info``; otherwise they are validated as ``.npy`` data and
            their length is ``shape[0]`` of the loaded array.

    Returns:
        ``(paths, lengths)``: dicts keyed by member file stem. ``paths`` maps
        to ``"<zip_path>:<data offset>:<size>"``; ``lengths`` maps to the
        length described above.

    Raises:
        AssertionError: if a member is compressed, has a malformed local
            header, or fails the audio / npy validation.
    """
    _zip_path = Path.joinpath(zip_root or Path(""), zip_path)
    with zipfile.ZipFile(_zip_path, mode="r") as f:
        info = f.infolist()
    paths, lengths = {}, {}
    # Open the archive once and seek to every member instead of reopening it
    # for each entry.
    with open(_zip_path, "rb") as f:
        for i in tqdm(info):
            utt_id = Path(i.filename).stem
            # A raw byte range is only the member's content if it is stored
            # without compression.
            assert i.compress_type == zipfile.ZIP_STORED, i
            offset = _zip_member_data_offset(f, i)
            file_size = i.file_size
            paths[utt_id] = f"{zip_path.as_posix()}:{offset}:{file_size}"
            f.seek(offset)
            byte_data = f.read(file_size)
            assert len(byte_data) > 1
            if is_audio:
                assert is_sf_audio_data(byte_data), i
            else:
                assert is_npy_data(byte_data), i
            byte_data_fp = io.BytesIO(byte_data)
            if is_audio:
                lengths[utt_id] = sf.info(byte_data_fp).frames
            else:
                lengths[utt_id] = np.load(byte_data_fp).shape[0]
    return paths, lengths
def gen_config_yaml(
    manifest_root: Path,
    spm_filename: Optional[str] = None,
    vocab_name: Optional[str] = None,
    yaml_filename: str = "config.yaml",
    specaugment_policy: Optional[str] = "lb",
    prepend_tgt_lang_tag: bool = False,
    sampling_alpha: Optional[float] = None,
    input_channels: Optional[int] = 1,
    input_feat_per_channel: Optional[int] = 80,
    audio_root: str = "",
    cmvn_type: str = "utterance",
    gcmvn_path: Optional[Path] = None,
    extra=None
):
    """Write an S2T data config YAML file to ``manifest_root / yaml_filename``.

    Args:
        manifest_root: directory holding the manifests; made absolute.
        spm_filename: SentencePiece model file name relative to
            *manifest_root*. When given, a sentencepiece ``bpe_tokenizer``
            entry is written, and it also supplies the default vocab name.
        vocab_name: dictionary file name. Defaults to *spm_filename* with a
            trailing ``.model`` replaced by ``.txt``.
        yaml_filename: output file name inside *manifest_root*.
        specaugment_policy: one of "lb", "ld", "sm", "ss" selects a preset;
            any non-None value also adds "specaugment" to the ``_train``
            transforms.
        prepend_tgt_lang_tag: set ``prepend_tgt_lang_tag: True`` if truthy.
        sampling_alpha: written when not None.
        input_channels / input_feat_per_channel: written when not None.
        audio_root: written when non-empty.
        cmvn_type: "utterance" or "global"; selects the cmvn transform.
        gcmvn_path: path of the global cmvn stats; required for "global".
        extra: optional dict merged into the top level of the config.

    Raises:
        AssertionError: if neither *spm_filename* nor *vocab_name* is given.
        NotImplementedError: if *cmvn_type* is not "global" or "utterance".
        ValueError: if *cmvn_type* is "global" and *gcmvn_path* is None.
    """
    # Validate arguments up front so nothing is built on invalid input.
    assert spm_filename is not None or vocab_name is not None
    if cmvn_type not in ["global", "utterance"]:
        raise NotImplementedError
    if cmvn_type == "global" and gcmvn_path is None:
        raise ValueError("Please provide path of global cmvn file.")

    manifest_root = manifest_root.absolute()
    writer = S2TDataConfigWriter(manifest_root / yaml_filename)

    if vocab_name is None:
        # Swap only the trailing ".model" suffix ("spm.model" -> "spm.txt");
        # a ".model" elsewhere in the name (e.g. a sub-directory) is kept.
        if spm_filename.endswith(".model"):
            vocab_name = spm_filename[: -len(".model")] + ".txt"
        else:
            vocab_name = spm_filename
    writer.set_vocab_filename(vocab_name)

    if input_channels is not None:
        writer.set_input_channels(input_channels)
    if input_feat_per_channel is not None:
        writer.set_input_feat_per_channel(input_feat_per_channel)

    specaugment_setters = {
        "lb": writer.set_specaugment_lb_policy,
        "ld": writer.set_specaugment_ld_policy,
        "sm": writer.set_specaugment_sm_policy,
        "ss": writer.set_specaugment_ss_policy,
    }
    specaugment_setter = specaugment_setters.get(specaugment_policy, None)
    if specaugment_setter is not None:
        specaugment_setter()

    if spm_filename is not None:
        writer.set_bpe_tokenizer(
            {
                "bpe": "sentencepiece",
                "sentencepiece_model": (manifest_root / spm_filename).as_posix(),
            }
        )
    if prepend_tgt_lang_tag:
        writer.set_prepend_tgt_lang_tag(True)
    if sampling_alpha is not None:
        writer.set_sampling_alpha(sampling_alpha)

    # NOTE(review): an unrecognised non-None policy still enables the
    # "specaugment" transform, just without a preset config - confirm intended.
    if specaugment_policy is not None:
        writer.set_feature_transforms(
            "_train", [f"{cmvn_type}_cmvn", "specaugment"]
        )
    writer.set_feature_transforms("*", [f"{cmvn_type}_cmvn"])
    if cmvn_type == "global":
        writer.set_global_cmvn(gcmvn_path.as_posix())

    if len(audio_root) > 0:
        writer.set_audio_root(audio_root)
    if extra is not None:
        writer.set_extra(extra)
    writer.flush()
def load_df_from_tsv(path: Union[str, Path]) -> pd.DataFrame:
    """Read a tab-separated manifest (with a header row) into a DataFrame.

    The file is decoded as UTF-8 with quoting disabled and ``\\`` as the
    escape character. NA detection is off, so empty cells stay as ``""``.
    """
    if not isinstance(path, str):
        path = path.as_posix()
    read_options = dict(
        sep="\t",
        header=0,
        encoding="utf-8",
        escapechar="\\",
        quoting=csv.QUOTE_NONE,
        na_filter=False,
    )
    return pd.read_csv(path, **read_options)
def save_df_to_tsv(dataframe, path: Union[str, Path]):
    """Write *dataframe* as a tab-separated manifest with a header row.

    Output is UTF-8, without the index, with quoting disabled and ``\\`` as
    the escape character.
    """
    destination = path.as_posix() if not isinstance(path, str) else path
    write_options = {
        "sep": "\t",
        "header": True,
        "index": False,
        "encoding": "utf-8",
        "escapechar": "\\",
        "quoting": csv.QUOTE_NONE,
    }
    dataframe.to_csv(destination, **write_options)
def load_tsv_to_dicts(path: Union[str, Path]) -> List[dict]:
    """Read a tab-separated file with a header row into a list of dicts.

    Each data row becomes a ``dict`` mapping header column -> string value.
    Quoting is disabled, so quote characters are kept literally.

    Args:
        path: path of the TSV file.

    Returns:
        One dict per data row, in file order.
    """
    # Decode as UTF-8 to match how save_df_to_tsv writes manifests, instead of
    # depending on the platform's locale encoding; newline="" is how the csv
    # module expects the file objects it reads to be opened.
    with open(path, "r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(
            f,
            delimiter="\t",
            quotechar=None,
            doublequote=False,
            lineterminator="\n",
            quoting=csv.QUOTE_NONE,
        )
        rows = [dict(e) for e in reader]
    return rows
def filter_manifest_df(
    df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000
):
    """Drop manifest rows matched by any filter and print a per-filter summary.

    Built-in filters: empty ``audio``, ``n_frames`` below *min_n_frames*,
    empty ``tgt_text``; for training splits also ``n_frames`` above
    *max_n_frames*. *extra_filters* (name -> boolean mask) are merged in
    after the built-in filters, replacing any with the same name.

    Returns:
        The rows of *df* that no filter matched.
    """
    audio, n_frames, tgt_text = df["audio"], df["n_frames"], df["tgt_text"]
    filters = {}
    filters["no speech"] = audio == ""
    filters[f"short speech (<{min_n_frames} frames)"] = n_frames < min_n_frames
    filters["empty sentence"] = tgt_text == ""
    if is_train_split:
        filters[f"long speech (>{max_n_frames} frames)"] = n_frames > max_n_frames
    if extra_filters is not None:
        filters.update(extra_filters)

    # A row is dropped if any single filter flags it.
    invalid = reduce(lambda acc, mask: acc | mask, filters.values())
    valid = ~invalid

    summary = ", ".join(f"{name}: {mask.sum()}" for name, mask in filters.items())
    print(f"| {summary}, total {invalid.sum()} filtered, {valid.sum()} remained.")
    return df[valid]
def cal_gcmvn_stats(features_list):
    """Compute per-dimension global mean and std over a list of feature arrays.

    The arrays are concatenated along axis 0 and statistics are taken over
    that axis (presumably frames x feature-dims - confirm against callers).
    The variance is floored at 1e-8 before the square root so a constant
    dimension yields std 1e-4 instead of 0.

    Returns:
        ``{"mean": ..., "std": ...}``, both float32 arrays.

    Raises:
        ValueError: from ``np.concatenate`` if *features_list* is empty.
    """
    features = np.concatenate(features_list)
    # Accumulate in float64 and use the two-pass variance (np.var subtracts
    # the mean first). E[x^2] - E[x]^2 in float32 cancels catastrophically
    # when the mean is large relative to the spread.
    mean = features.mean(axis=0, dtype=np.float64)
    var = features.var(axis=0, dtype=np.float64)
    std = np.sqrt(np.maximum(var, 1e-8))
    return {"mean": mean.astype("float32"), "std": std.astype("float32")}
class S2TDataConfigWriter(object):
    """Incrementally build an S2T data config dict and dump it as YAML.

    Each ``set_*`` method records an entry in ``self.config``; nothing is
    written to disk until :meth:`flush` is called.
    """

    DEFAULT_VOCAB_FILENAME = "dict.txt"
    DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
    DEFAULT_INPUT_CHANNELS = 1

    def __init__(self, yaml_path: Path):
        """Create an empty config that :meth:`flush` writes to *yaml_path*.

        Raises:
            ImportError: if PyYAML is not installed.
        """
        try:
            import yaml
        except ImportError as e:
            # Fail here with an actionable message; continuing without the
            # module would only crash on the assignment below.
            raise ImportError(
                "Please install PyYAML for S2T data config YAML files"
            ) from e
        self.yaml = yaml
        self.yaml_path = yaml_path
        self.config = {}

    def flush(self):
        """Dump the accumulated config to ``self.yaml_path`` (overwriting it)."""
        with open(self.yaml_path, "w", encoding="utf-8") as f:
            self.yaml.dump(self.config, f)

    def set_audio_root(self, audio_root=""):
        """Set the ``audio_root`` entry."""
        self.config["audio_root"] = audio_root

    def set_vocab_filename(self, vocab_filename: str = "dict.txt"):
        """Set the ``vocab_filename`` entry."""
        self.config["vocab_filename"] = vocab_filename

    def set_specaugment(
        self,
        time_wrap_w: int,
        freq_mask_n: int,
        freq_mask_f: int,
        time_mask_n: int,
        time_mask_t: int,
        time_mask_p: float,
    ):
        """Set the ``specaugment`` entry from explicit parameters."""
        # Key spellings (including "time_wrap_W") are kept exactly as the
        # config consumer presumably reads them - do not rename.
        self.config["specaugment"] = {
            "time_wrap_W": time_wrap_w,
            "freq_mask_N": freq_mask_n,
            "freq_mask_F": freq_mask_f,
            "time_mask_N": time_mask_n,
            "time_mask_T": time_mask_t,
            "time_mask_p": time_mask_p,
        }

    def set_specaugment_lb_policy(self):
        """Apply the "lb" SpecAugment preset."""
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=1,
            freq_mask_f=27,
            time_mask_n=1,
            time_mask_t=100,
            time_mask_p=1.0,
        )

    def set_specaugment_ld_policy(self):
        """Apply the "ld" SpecAugment preset."""
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=2,
            freq_mask_f=27,
            time_mask_n=2,
            time_mask_t=100,
            time_mask_p=1.0,
        )

    def set_specaugment_sm_policy(self):
        """Apply the "sm" SpecAugment preset."""
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=2,
            freq_mask_f=15,
            time_mask_n=2,
            time_mask_t=70,
            time_mask_p=0.2,
        )

    def set_specaugment_ss_policy(self):
        """Apply the "ss" SpecAugment preset."""
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=2,
            freq_mask_f=27,
            time_mask_n=2,
            time_mask_t=70,
            time_mask_p=0.2,
        )

    def set_input_channels(self, input_channels: int = 1):
        """Set the ``input_channels`` entry."""
        self.config["input_channels"] = input_channels

    def set_input_feat_per_channel(self, input_feat_per_channel: int = 80):
        """Set the ``input_feat_per_channel`` entry."""
        self.config["input_feat_per_channel"] = input_feat_per_channel

    def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
        """Set the ``bpe_tokenizer`` entry to the given mapping."""
        self.config["bpe_tokenizer"] = bpe_tokenizer

    def set_global_cmvn(self, stats_npz_path: str):
        """Set ``global_cmvn.stats_npz_path``."""
        self.config["global_cmvn"] = {"stats_npz_path": stats_npz_path}

    def set_feature_transforms(self, split: str, transforms: List[str]):
        """Set the transform list for *split* under ``transforms``."""
        if "transforms" not in self.config:
            self.config["transforms"] = {}
        self.config["transforms"][split] = transforms

    def set_prepend_tgt_lang_tag(self, flag: bool = True):
        """Set the ``prepend_tgt_lang_tag`` entry."""
        self.config["prepend_tgt_lang_tag"] = flag

    def set_sampling_alpha(self, sampling_alpha: float = 1.0):
        """Set the ``sampling_alpha`` entry."""
        self.config["sampling_alpha"] = sampling_alpha

    def set_extra(self, data):
        """Merge the mapping *data* into the top level of the config."""
        self.config.update(data)