Plachta commited on
Commit
3c72736
·
1 Parent(s): 11433cb
Files changed (1) hide show
  1. data/tokenizer.py +0 -260
data/tokenizer.py CHANGED
@@ -22,160 +22,6 @@ import torch
22
  import torchaudio
23
  from encodec import EncodecModel
24
  from encodec.utils import convert_audio
25
- from phonemizer.backend import EspeakBackend
26
- from phonemizer.backend.espeak.language_switch import LanguageSwitch
27
- from phonemizer.backend.espeak.words_mismatch import WordMismatch
28
- from phonemizer.punctuation import Punctuation
29
- from phonemizer.separator import Separator
30
- from phonemizer.separator import Separator
31
-
32
- try:
33
- from pypinyin import Style, pinyin
34
- from pypinyin.style._utils import get_finals, get_initials
35
- except Exception:
36
- pass
37
-
38
-
39
- class PypinyinBackend:
40
- """PypinyinBackend for Chinese. Most codes is referenced from espnet.
41
- There are two types pinyin or initials_finals, one is
42
- just like "ni1 hao3", the other is like "n i1 h ao3".
43
- """
44
-
45
- def __init__(
46
- self,
47
- backend="initials_finals",
48
- punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
49
- ) -> None:
50
- self.backend = backend
51
- self.punctuation_marks = punctuation_marks
52
-
53
- def phonemize(
54
- self, text: List[str], separator: Separator, strip=True, njobs=1
55
- ) -> List[str]:
56
- assert isinstance(text, List)
57
- phonemized = []
58
- for _text in text:
59
- _text = re.sub(" +", " ", _text.strip())
60
- _text = _text.replace(" ", separator.word)
61
- phones = []
62
- if self.backend == "pypinyin":
63
- for n, py in enumerate(
64
- pinyin(
65
- _text, style=Style.TONE3, neutral_tone_with_five=True
66
- )
67
- ):
68
- if all([c in self.punctuation_marks for c in py[0]]):
69
- if len(phones):
70
- assert phones[-1] == separator.syllable
71
- phones.pop(-1)
72
-
73
- phones.extend(list(py[0]))
74
- else:
75
- phones.extend([py[0], separator.syllable])
76
- elif self.backend == "pypinyin_initials_finals":
77
- for n, py in enumerate(
78
- pinyin(
79
- _text, style=Style.TONE3, neutral_tone_with_five=True
80
- )
81
- ):
82
- if all([c in self.punctuation_marks for c in py[0]]):
83
- if len(phones):
84
- assert phones[-1] == separator.syllable
85
- phones.pop(-1)
86
- phones.extend(list(py[0]))
87
- else:
88
- if py[0][-1].isalnum():
89
- initial = get_initials(py[0], strict=False)
90
- if py[0][-1].isdigit():
91
- final = (
92
- get_finals(py[0][:-1], strict=False)
93
- + py[0][-1]
94
- )
95
- else:
96
- final = get_finals(py[0], strict=False)
97
- phones.extend(
98
- [
99
- initial,
100
- separator.phone,
101
- final,
102
- separator.syllable,
103
- ]
104
- )
105
- else:
106
- assert ValueError
107
- else:
108
- raise NotImplementedError
109
- phonemized.append(
110
- "".join(phones).rstrip(f"{separator.word}{separator.syllable}")
111
- )
112
- return phonemized
113
-
114
-
115
- class TextTokenizer:
116
- """Phonemize Text."""
117
-
118
- def __init__(
119
- self,
120
- language="en-us",
121
- backend="espeak",
122
- separator=Separator(word="_", syllable="-", phone="|"),
123
- preserve_punctuation=True,
124
- punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
125
- with_stress: bool = False,
126
- tie: Union[bool, str] = False,
127
- language_switch: LanguageSwitch = "keep-flags",
128
- words_mismatch: WordMismatch = "ignore",
129
- ) -> None:
130
- if backend == "espeak":
131
- phonemizer = EspeakBackend(
132
- language,
133
- punctuation_marks=punctuation_marks,
134
- preserve_punctuation=preserve_punctuation,
135
- with_stress=with_stress,
136
- tie=tie,
137
- language_switch=language_switch,
138
- words_mismatch=words_mismatch,
139
- )
140
- elif backend in ["pypinyin", "pypinyin_initials_finals"]:
141
- phonemizer = PypinyinBackend(
142
- backend=backend,
143
- punctuation_marks=punctuation_marks + separator.word,
144
- )
145
- else:
146
- raise NotImplementedError(f"{backend}")
147
-
148
- self.backend = phonemizer
149
- self.separator = separator
150
-
151
- def to_list(self, phonemized: str) -> List[str]:
152
- fields = []
153
- for word in phonemized.split(self.separator.word):
154
- # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
155
- pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
156
- fields.extend(
157
- [p for p in pp if p != self.separator.phone]
158
- + [self.separator.word]
159
- )
160
- assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
161
- self.separator.phone
162
- )
163
- return fields[:-1]
164
-
165
- def __call__(self, text, strip=True) -> List[List[str]]:
166
- if isinstance(text, str):
167
- text = [text]
168
-
169
- phonemized = self.backend.phonemize(
170
- text, separator=self.separator, strip=strip, njobs=1
171
- )
172
- return [self.to_list(p) for p in phonemized]
173
-
174
-
175
- def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
176
- phonemes = tokenizer([text.strip()])
177
- return phonemes[0] # k2symbols
178
-
179
 
180
  def remove_encodec_weight_norm(model):
181
  from encodec.modules import SConv1d
@@ -256,112 +102,6 @@ def tokenize_audio(tokenizer: AudioTokenizer, audio):
256
  return encoded_frames
257
 
258
 
259
- # @dataclass
260
- # class AudioTokenConfig:
261
- # frame_shift: Seconds = 320.0 / 24000
262
- # num_quantizers: int = 8
263
- #
264
- # def to_dict(self) -> Dict[str, Any]:
265
- # return asdict(self)
266
- #
267
- # @staticmethod
268
- # def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
269
- # return AudioTokenConfig(**data)
270
- #
271
- #
272
- # class AudioTokenExtractor(FeatureExtractor):
273
- # name = "encodec"
274
- # config_type = AudioTokenConfig
275
- #
276
- # def __init__(self, config: Optional[Any] = None):
277
- # super(AudioTokenExtractor, self).__init__(config)
278
- # self.tokenizer = AudioTokenizer()
279
- #
280
- # def extract(
281
- # self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
282
- # ) -> np.ndarray:
283
- # if not isinstance(samples, torch.Tensor):
284
- # samples = torch.from_numpy(samples)
285
- # if sampling_rate != self.tokenizer.sample_rate:
286
- # samples = convert_audio(
287
- # samples,
288
- # sampling_rate,
289
- # self.tokenizer.sample_rate,
290
- # self.tokenizer.channels,
291
- # )
292
- # if len(samples.shape) == 2:
293
- # samples = samples.unsqueeze(0)
294
- # else:
295
- # raise ValueError()
296
- #
297
- # device = self.tokenizer.device
298
- # encoded_frames = self.tokenizer.encode(samples.detach().to(device))
299
- # codes = encoded_frames[0][0] # [B, n_q, T]
300
- # if True:
301
- # duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
302
- # expected_num_frames = compute_num_frames(
303
- # duration=duration,
304
- # frame_shift=self.frame_shift,
305
- # sampling_rate=sampling_rate,
306
- # )
307
- # assert abs(codes.shape[-1] - expected_num_frames) <= 1
308
- # codes = codes[..., :expected_num_frames]
309
- # return codes.cpu().squeeze(0).permute(1, 0).numpy()
310
- #
311
- # @property
312
- # def frame_shift(self) -> Seconds:
313
- # return self.config.frame_shift
314
- #
315
- # def feature_dim(self, sampling_rate: int) -> int:
316
- # return self.config.num_quantizers
317
- #
318
- # def pad_tensor_list(self, tensor_list, device, padding_value=0):
319
- # # 计算每个张量的长度
320
- # lengths = [tensor.shape[0] for tensor in tensor_list]
321
- # # 使用pad_sequence函数进行填充
322
- # tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
323
- # padded_tensor = torch.nn.utils.rnn.pad_sequence(
324
- # tensor_list, batch_first=True, padding_value=padding_value
325
- # )
326
- # return padded_tensor, lengths
327
- #
328
- # def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
329
- # samples = [wav.squeeze() for wav in samples]
330
- # device = self.tokenizer.device
331
- # samples, lengths = self.pad_tensor_list(samples, device)
332
- # samples = samples.unsqueeze(1)
333
- #
334
- # if not isinstance(samples, torch.Tensor):
335
- # samples = torch.from_numpy(samples)
336
- # if len(samples.shape) != 3:
337
- # raise ValueError()
338
- # if sampling_rate != self.tokenizer.sample_rate:
339
- # samples = [
340
- # convert_audio(
341
- # wav,
342
- # sampling_rate,
343
- # self.tokenizer.sample_rate,
344
- # self.tokenizer.channels,
345
- # )
346
- # for wav in samples
347
- # ]
348
- # # Extract discrete codes from EnCodec
349
- # with torch.no_grad():
350
- # encoded_frames = self.tokenizer.encode(samples.detach().to(device))
351
- # encoded_frames = encoded_frames[0][0] # [B, n_q, T]
352
- # batch_codes = []
353
- # for b, length in enumerate(lengths):
354
- # codes = encoded_frames[b]
355
- # duration = round(length / sampling_rate, ndigits=12)
356
- # expected_num_frames = compute_num_frames(
357
- # duration=duration,
358
- # frame_shift=self.frame_shift,
359
- # sampling_rate=sampling_rate,
360
- # )
361
- # batch_codes.append(codes[..., :expected_num_frames])
362
- # return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]
363
-
364
-
365
  if __name__ == "__main__":
366
  model = EncodecModel.encodec_model_24khz()
367
  model.set_target_bandwidth(6.0)
 
22
  import torchaudio
23
  from encodec import EncodecModel
24
  from encodec.utils import convert_audio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def remove_encodec_weight_norm(model):
27
  from encodec.modules import SConv1d
 
102
  return encoded_frames
103
 
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  if __name__ == "__main__":
106
  model = EncodecModel.encodec_model_24khz()
107
  model.set_target_bandwidth(6.0)