import json
import logging
import os
from pathlib import Path
import re
from transformers import SpeechT5Tokenizer
from transformers.models.speecht5.tokenization_speecht5 import (
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES,
)
from itertools import chain
from typing import List, Optional, Tuple
logger = logging.getLogger(__name__)
NP_CHARCTERS = " !\"#$%&'()=~|`{+*}<>?_-^\\@[;:],./ !”#$%&’()=~|`{+*}<>?_ー^¥@「;:」、。・`"
def _g2p_with_np(text: str, np_lsit: str) -> List[str]:
from pyopenjtalk import g2p
np_pattern = re.compile(f"([{re.escape(np_lsit)}])")
return list(
chain.from_iterable(
[
(text,) if text in np_lsit else g2p(text, kana=False, join=False)
for text in np_pattern.split(text)
if len(text) > 0
]
)
)
VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"esnya/japanese_speecht5_tts": "https://huggingface.co/esnya/japanese_speecht5_tts/resolve/main/vocab.json",
},
}
class SpeechT5OpenjtalkTokenizer(SpeechT5Tokenizer):
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
bos_token: str = "",
eos_token: str = "",
unk_token: str = "",
pad_token: str = "",
non_phenome_characters: str = NP_CHARCTERS,
**kwargs,
):
try:
super().__init__(
vocab_file=None,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
**kwargs,
)
except TypeError:
pass
self.non_phenome_characters = non_phenome_characters
self.vocab_file = vocab_file
self._load_vocab()
def _load_vocab(self):
if isinstance(self.vocab_file, str) and self.vocab_file.endswith(".json"):
with open(self.vocab_file, encoding="utf-8") as f:
self.label2id = json.load(f)
self.id2label = {v: k for k, v in self.label2id.items()}
@property
def bos_token_id(self) -> int | None:
return super().bos_token_id
@property
def vocab_size(self):
return len(self.label2id)
def get_vocab(self):
return self.label2id
def __getstate__(self):
state = super().__getstate__()
del state["sp_model"]
return state
def __setstate__(self, d):
self.__dict__ = d
self._load_vocab()
def save_vocabulary(
self, save_directory: str, filename_prefix: Optional[str] = None
):
if filename_prefix is None:
filename_prefix = ".json"
save_path = Path(save_directory)
if not save_path.is_dir():
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
vocab_path = Path(save_directory) / Path(f"vocab{filename_prefix}")
vocab_path.parent.mkdir(parents=True, exist_ok=True)
with open(vocab_path, "w", encoding="utf-8") as f:
json.dump(self.label2id, f, ensure_ascii=False, indent=2)
return (str(vocab_path),)
def _tokenize(self, text: str) -> List[str]:
return _g2p_with_np(text, self.non_phenome_characters)
def _convert_token_to_id(self, token):
return self.label2id.get(token, self.label2id.get(self.unk_token))
def _convert_id_to_token(self, index):
return self.id2label.get(index, self.unk_token)