balaramas committed
Commit a72450c
1 Parent(s): d9a7ea1

Upload prep_mustc_data.py

Files changed (1)
  1. prep_mustc_data.py +294 -0
prep_mustc_data.py ADDED
@@ -0,0 +1,294 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+import os
+from pathlib import Path
+import shutil
+from itertools import groupby
+from tempfile import NamedTemporaryFile
+from typing import Tuple
+
+import numpy as np
+import pandas as pd
+import soundfile as sf
+from examples.speech_to_text.data_utils import (
+    create_zip,
+    extract_fbank_features,
+    filter_manifest_df,
+    gen_config_yaml,
+    gen_vocab,
+    get_zip_manifest,
+    load_df_from_tsv,
+    save_df_to_tsv,
+    cal_gcmvn_stats,
+)
+import torch
+from torch.utils.data import Dataset
+from tqdm import tqdm
+
+from fairseq.data.audio.audio_utils import get_waveform, convert_waveform
+
+
+log = logging.getLogger(__name__)
+
+
+MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
+
+
+class MUSTC(Dataset):
+    """
+    Create a Dataset for MuST-C. Each item is a tuple of the form:
+    waveform, sample_rate, source utterance, target utterance, speaker_id,
+    utterance_id
+    """
+
+    SPLITS = ["train", "dev", "tst-COMMON", "tst-HE"]
+    LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru", "hi", "bn"]
+
+    def __init__(self, root: str, lang: str, split: str) -> None:
+        assert split in self.SPLITS and lang in self.LANGUAGES
+        _root = Path(root) / f"en-{lang}" / "data" / split
+        wav_root, txt_root = _root / "wav", _root / "txt"
+        assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir()
+        # Load audio segments
+        try:
+            import yaml
+        except ImportError:
+            raise ImportError("Please install PyYAML to load the MuST-C YAML files")
+        with open(txt_root / f"{split}.yaml") as f:
+            segments = yaml.load(f, Loader=yaml.BaseLoader)
+        # Load source and target utterances
+        for _lang in ["en", lang]:
+            with open(txt_root / f"{split}.{_lang}") as f:
+                utterances = [r.strip() for r in f]
+            assert len(segments) == len(utterances)
+            for i, u in enumerate(utterances):
+                segments[i][_lang] = u
+        # Gather info
+        self.data = []
+        for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
+            wav_path = wav_root / wav_filename
+            sample_rate = sf.info(wav_path.as_posix()).samplerate
+            seg_group = sorted(_seg_group, key=lambda x: x["offset"])
+            for i, segment in enumerate(seg_group):
+                offset = int(float(segment["offset"]) * sample_rate)
+                n_frames = int(float(segment["duration"]) * sample_rate)
+                _id = f"{wav_path.stem}_{i}"
+                self.data.append(
+                    (
+                        wav_path.as_posix(),
+                        offset,
+                        n_frames,
+                        sample_rate,
+                        segment["en"],
+                        segment[lang],
+                        segment["speaker_id"],
+                        _id,
+                    )
+                )
+
+    def __getitem__(
+        self, n: int
+    ) -> Tuple[torch.Tensor, int, str, str, str, str]:
+        wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, \
+            utt_id = self.data[n]
+        waveform, _ = get_waveform(wav_path, frames=n_frames, start=offset)
+        waveform = torch.from_numpy(waveform)
+        return waveform, sr, src_utt, tgt_utt, spk_id, utt_id
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+
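# A hedged usage sketch for the class above (MUSTC_ROOT is an illustrative
# path, not taken from this commit): each item unpacks as per the docstring.
#
#     dataset = MUSTC(MUSTC_ROOT, lang="hi", split="dev")
#     waveform, sr, src_utt, tgt_utt, spk_id, utt_id = dataset[0]
#     # waveform: (channels x samples) float tensor cut from the source wav;
#     # src_utt is the English transcript, tgt_utt the Hindi translation.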
+def process(args):
+    root = Path(args.data_root).absolute()
+    for lang in MUSTC.LANGUAGES:
+        cur_root = root / f"en-{lang}"
+        if not cur_root.is_dir():
+            print(f"{cur_root.as_posix()} does not exist. Skipped.")
+            continue
+        # Extract features
+        audio_root = cur_root / ("flac" if args.use_audio_input else "fbank80")
+        audio_root.mkdir(exist_ok=True)
+
+        for split in MUSTC.SPLITS:
+            print(f"Fetching split {split}...")
+            dataset = MUSTC(root.as_posix(), lang, split)
+            if args.use_audio_input:
+                print("Converting audio...")
+                for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
+                    tgt_sample_rate = 16_000
+                    _wavform, _ = convert_waveform(
+                        waveform, sample_rate, to_mono=True,
+                        to_sample_rate=tgt_sample_rate
+                    )
+                    sf.write(
+                        (audio_root / f"{utt_id}.flac").as_posix(),
+                        _wavform.T.numpy(), tgt_sample_rate
+                    )
+            else:
+                print("Extracting log mel filter bank features...")
+                gcmvn_feature_list = []
+                if split == "train" and args.cmvn_type == "global":
+                    print("And estimating cepstral mean and variance stats...")
+
+                for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
+                    features = extract_fbank_features(
+                        waveform, sample_rate, audio_root / f"{utt_id}.npy"
+                    )
+                    if split == "train" and args.cmvn_type == "global":
+                        if len(gcmvn_feature_list) < args.gcmvn_max_num:
+                            gcmvn_feature_list.append(features)
+
+                if split == "train" and args.cmvn_type == "global":
+                    # Estimate and save cmvn stats
+                    stats = cal_gcmvn_stats(gcmvn_feature_list)
+                    with open(cur_root / "gcmvn.npz", "wb") as f:
+                        np.savez(f, mean=stats["mean"], std=stats["std"])
+
+        # Pack features into ZIP
+        zip_path = cur_root / f"{audio_root.name}.zip"
+        print("Zipping audio/features...")
+        create_zip(audio_root, zip_path)
+        print("Fetching ZIP manifest...")
+        audio_paths, audio_lengths = get_zip_manifest(
+            zip_path,
+            is_audio=args.use_audio_input,
+        )
+        # Generate TSV manifest
+        print("Generating manifest...")
+        train_text = []
+        for split in MUSTC.SPLITS:
+            is_train_split = split.startswith("train")
+            manifest = {c: [] for c in MANIFEST_COLUMNS}
+            dataset = MUSTC(args.data_root, lang, split)
+            for _, _, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
+                manifest["id"].append(utt_id)
+                manifest["audio"].append(audio_paths[utt_id])
+                manifest["n_frames"].append(audio_lengths[utt_id])
+                manifest["tgt_text"].append(
+                    src_utt if args.task == "asr" else tgt_utt
+                )
+                manifest["speaker"].append(speaker_id)
+            if is_train_split:
+                train_text.extend(manifest["tgt_text"])
+            df = pd.DataFrame.from_dict(manifest)
+            df = filter_manifest_df(df, is_train_split=is_train_split)
+            save_df_to_tsv(df, cur_root / f"{split}_{args.task}.tsv")
+        # Generate vocab
+        v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
+        spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
+        with NamedTemporaryFile(mode="w") as f:
+            for t in train_text:
+                f.write(t + "\n")
+            f.flush()  # flush buffered text so gen_vocab reads the complete file
+            gen_vocab(
+                Path(f.name),
+                cur_root / spm_filename_prefix,
+                args.vocab_type,
+                args.vocab_size,
+            )
+        # Generate config YAML
+        if args.use_audio_input:
+            gen_config_yaml(
+                cur_root,
+                spm_filename=spm_filename_prefix + ".model",
+                yaml_filename=f"config_{args.task}.yaml",
+                specaugment_policy=None,
+                extra={"use_audio_input": True}
+            )
+        else:
+            gen_config_yaml(
+                cur_root,
+                spm_filename=spm_filename_prefix + ".model",
+                yaml_filename=f"config_{args.task}.yaml",
+                specaugment_policy="lb",
+                cmvn_type=args.cmvn_type,
+                gcmvn_path=(
+                    cur_root / "gcmvn.npz" if args.cmvn_type == "global"
+                    else None
+                ),
+            )
+        # Clean up
+        shutil.rmtree(audio_root)
+
+
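# A hedged sketch of consuming the gcmvn.npz written above: np.savez stores
# "mean" and "std" arrays, so applying global CMVN to extracted features
# reduces to (path is illustrative):
#
#     stats = np.load("en-hi/gcmvn.npz")
#     normalized = (features - stats["mean"]) / stats["std"]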
+def process_joint(args):
+    cur_root = Path(args.data_root)
+    assert all(
+        (cur_root / f"en-{lang}").is_dir() for lang in MUSTC.LANGUAGES
+    ), "need downloaded data for all 10 language pairs"
+    # Generate vocab
+    vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
+    spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{args.task}"
+    with NamedTemporaryFile(mode="w") as f:
+        for lang in MUSTC.LANGUAGES:
+            tsv_path = cur_root / f"en-{lang}" / f"train_{args.task}.tsv"
+            df = load_df_from_tsv(tsv_path)
+            for t in df["tgt_text"]:
+                f.write(t + "\n")
+        f.flush()  # flush buffered text so gen_vocab reads the complete file
+        special_symbols = None
+        if args.task == "st":
+            special_symbols = [f"<lang:{lang}>" for lang in MUSTC.LANGUAGES]
+        gen_vocab(
+            Path(f.name),
+            cur_root / spm_filename_prefix,
+            args.vocab_type,
+            args.vocab_size,
+            special_symbols=special_symbols
+        )
+    # Generate config YAML
+    gen_config_yaml(
+        cur_root,
+        spm_filename=spm_filename_prefix + ".model",
+        yaml_filename=f"config_{args.task}.yaml",
+        specaugment_policy="ld",
+        prepend_tgt_lang_tag=(args.task == "st"),
+    )
+    # Make symbolic links to manifests
+    for lang in MUSTC.LANGUAGES:
+        for split in MUSTC.SPLITS:
+            src_path = cur_root / f"en-{lang}" / f"{split}_{args.task}.tsv"
+            desc_path = cur_root / f"{split}_{lang}_{args.task}.tsv"
+            if not desc_path.is_symlink():
+                os.symlink(src_path, desc_path)
+
+
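# A hedged sketch of the layout produced by the symlink loop above, assuming
# task=st (pair names illustrative):
#
#     train_de_st.tsv -> en-de/train_st.tsv
#     dev_hi_st.tsv   -> en-hi/dev_st.tsv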
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--data-root", "-d", required=True, type=str)
+    parser.add_argument(
+        "--vocab-type",
+        default="unigram",
+        type=str,
+        choices=["bpe", "unigram", "char"],
+    )
+    parser.add_argument("--vocab-size", default=8000, type=int)
+    parser.add_argument("--task", type=str, choices=["asr", "st"])
+    parser.add_argument(
+        "--joint", action="store_true",
+        help="Build a joint vocabulary and config over all language pairs"
+    )
+    parser.add_argument(
+        "--cmvn-type", default="utterance",
+        choices=["global", "utterance"],
+        help="The type of cepstral mean and variance normalization"
+    )
+    parser.add_argument(
+        "--gcmvn-max-num", default=150000, type=int,
+        help="Maximum number of sentences to use to estimate global mean and "
+             "variance"
+    )
+    parser.add_argument("--use-audio-input", action="store_true")
+    args = parser.parse_args()
+
+    if args.joint:
+        process_joint(args)
+    else:
+        process(args)
+
+
+if __name__ == "__main__":
+    main()
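A minimal invocation sketch, assuming the MuST-C releases sit under ${MUSTC_ROOT} (an illustrative variable) and that the script is run from a fairseq checkout, since it imports examples.speech_to_text.data_utils; this mirrors the fairseq speech_to_text recipe:

    python prep_mustc_data.py --data-root ${MUSTC_ROOT} --task asr --vocab-type unigram
    python prep_mustc_data.py --data-root ${MUSTC_ROOT} --task st --vocab-type unigram

Each en-xx directory under the data root then receives a feature ZIP (fbank80.zip, or flac.zip with --use-audio-input), per-split {split}_{task}.tsv manifests, a SentencePiece model, and a config_{task}.yaml; a further run with --joint builds the shared multilingual vocabulary and the symlinked manifests.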