# NaturalSpeech2 / text /text_token_collation.py
# yuancwang
# init
# b725c5a
# Copyright (c) 2023 Amphion.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
from typing import List, Tuple
import os
import numpy as np
import torch
from text.symbol_table import SymbolTable
from text import text_to_sequence
"""
TextToken: map text to id
"""
# TextTokenCollator is modified from
# https://github.com/lifeiteng/vall-e/blob/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/data/collation.py
class TextTokenCollator:
    """Map text tokens to integer ids, optionally wrapping them in BOS/EOS.

    The id table ``idx2token`` is laid out as
    ``[pad_symbol, (bos_symbol), (eos_symbol), *sorted(text_tokens)]``,
    where the BOS/EOS entries are present only when enabled.
    """

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
    ):
        """
        Args:
            text_tokens: Regular vocabulary tokens; they are sorted before
                being assigned ids.
            add_eos: Append ``eos_symbol`` to every sequence.
            add_bos: Prepend ``bos_symbol`` to every sequence.
            pad_symbol: Symbol used to right-pad batches in :meth:`index`.
            bos_symbol: Beginning-of-sequence symbol.
            eos_symbol: End-of-sequence symbol.
        """
        self.pad_symbol = pad_symbol
        self.add_eos = add_eos
        self.add_bos = add_bos
        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        unique_tokens = [pad_symbol]
        if add_bos:
            unique_tokens.append(bos_symbol)
        if add_eos:
            unique_tokens.append(eos_symbol)
        # NOTE(review): text_tokens is not de-duplicated against the special
        # symbols; if it already contains one, token2idx keeps the later index.
        # Left unchanged so ids stay compatible with existing vocabularies.
        unique_tokens.extend(sorted(text_tokens))

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = unique_tokens

    def _wrap(self, tokens) -> List[str]:
        """Return ``tokens`` as a list with BOS/EOS added as configured."""
        return (
            ([self.bos_symbol] if self.add_bos else [])
            + list(tokens)
            + ([self.eos_symbol] if self.add_eos else [])
        )

    def index(self, tokens_list: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert a batch of token sequences into a right-padded id tensor.

        Args:
            tokens_list: Batch of token sequences. Each item is iterated token
                by token, so it may be a list of token strings or a plain
                string (one token per character).

        Returns:
            ``(tokens, tokens_lens)``: an int64 tensor of shape
            ``(batch, max_len)`` right-padded with the id of ``pad_symbol``,
            and an int32 tensor of per-sequence lengths (BOS/EOS included).
            An empty batch yields tensors of shape ``(0, 0)`` and ``(0,)``.

        Raises:
            AssertionError: if a sequence contains a token outside the
                vocabulary.
        """
        pad_id = self.token2idx[self.pad_symbol]
        id_rows, seq_lens = [], []
        for tokens in tokens_list:
            unknown = [s for s in tokens if s not in self.token2idx]
            assert not unknown, f"Unknown text tokens: {unknown}"
            seq = self._wrap(tokens)
            id_rows.append([self.token2idx[token] for token in seq])
            seq_lens.append(len(seq))

        # default=0 keeps an empty batch from raising ValueError in max().
        max_len = max(seq_lens, default=0)
        padded = np.full((len(id_rows), max_len), pad_id, dtype=np.int64)
        for row, ids in enumerate(id_rows):
            padded[row, : len(ids)] = ids

        tokens = torch.from_numpy(padded)
        tokens_lens = torch.IntTensor(seq_lens)
        return tokens, tokens_lens

    def __call__(self, text):
        """Map a single token sequence to ids.

        Args:
            text: Token sequence, iterated token by token.

        Returns:
            ``(token_ids, token_lens)``: the list of int ids (BOS/EOS added as
            configured) and its length.

        Raises:
            KeyError: if ``text`` contains a token outside the vocabulary.
        """
        seq = self._wrap(text)
        token_ids = [self.token2idx[token] for token in seq]
        return token_ids, len(seq)
def get_text_token_collater(
    text_tokens_file: str,
) -> Tuple["TextTokenCollator", dict]:
    """Build a :class:`TextTokenCollator` from a symbol-table file.

    Args:
        text_tokens_file: Path to a file readable by ``SymbolTable.from_file``.

    Returns:
        ``(collater, token2idx)``: a collater with BOS and EOS enabled, built
        from the file's symbols, and that collater's token-to-id mapping.
    """
    unique_tokens = SymbolTable.from_file(Path(text_tokens_file))
    collater = TextTokenCollator(unique_tokens.symbols, add_bos=True, add_eos=True)
    return collater, collater.token2idx
class phoneIDCollation:
    """Convert phone sequences to id sequences according to ``cfg``.

    When ``cfg.preprocess.phone_extractor == "lexicon"``, ids come from
    ``text_to_sequence``. Otherwise a text-token collater is built from a
    symbols dictionary file and used for the lookup.
    """

    def __init__(self, cfg, dataset=None, symbols_dict_file=None) -> None:
        """
        Args:
            cfg: Config object. Reads ``cfg.preprocess.phone_extractor`` and,
                when ``symbols_dict_file`` is not given,
                ``cfg.preprocess.processed_dir`` and
                ``cfg.preprocess.symbols_dict``.
            dataset: Dataset sub-directory name, joined as
                ``processed_dir/dataset/symbols_dict`` to locate the symbols
                file when ``symbols_dict_file`` is None.
            symbols_dict_file: Explicit path to the symbols dictionary file.

        Raises:
            ValueError: if the extractor is not "lexicon" and neither
                ``symbols_dict_file`` nor ``dataset`` is provided.
        """
        # The attribute name (including its spelling) is kept for callers that
        # access it directly; it stays None for the lexicon extractor.
        self.text_token_colloator = None
        if cfg.preprocess.phone_extractor == "lexicon":
            return

        if symbols_dict_file is None:
            if dataset is None:
                raise ValueError(
                    "Either `symbols_dict_file` or `dataset` must be provided "
                    "when phone_extractor is not 'lexicon'."
                )
            symbols_dict_file = os.path.join(
                cfg.preprocess.processed_dir, dataset, cfg.preprocess.symbols_dict
            )
        self.text_token_colloator, _ = get_text_token_collater(symbols_dict_file)

    def get_phone_id_sequence(self, cfg, phones_seq):
        """Return the id sequence for ``phones_seq``.

        Args:
            cfg: Config object; reads ``cfg.preprocess.phone_extractor`` and,
                for the lexicon extractor, ``cfg.preprocess.text_cleaners``.
            phones_seq: Sequence of phone strings.

        Returns:
            For the lexicon extractor, whatever ``text_to_sequence`` returns for
            the space-joined phones. Otherwise, the list of ids produced by the
            text-token collater (BOS/EOS included).
        """
        if cfg.preprocess.phone_extractor == "lexicon":
            phones_seq = " ".join(phones_seq)
            return text_to_sequence(phones_seq, cfg.preprocess.text_cleaners)
        sequence, _ = self.text_token_colloator(phones_seq)
        return sequence