import json
import os
import pathlib
import re
from os.path import join as join_path

from huggingface_hub import snapshot_download

from .accent_model import AccentModel
from .omograph_model import OmographModel

class RUAccent:
    """Russian text accentuator: restores the letter "ё", resolves homographs
    and places stress marks using dictionaries and neural models."""

    def __init__(self, workdir=None, allow_cuda=True):
        self.omograph_model = OmographModel(allow_cuda=allow_cuda)
        self.accent_model = AccentModel(allow_cuda=allow_cuda)
        # Default to the package directory so downloaded resources live next to the code.
        if not workdir:
            self.workdir = str(pathlib.Path(__file__).resolve().parent)
        else:
            self.workdir = workdir
    def load(
        self,
        omograph_model_size="medium",
        dict_load_startup=False,
        disable_accent_dict=False,
        repo="TeraTTS/accentuator",
    ):
        # Download dictionaries and model weights on first use.
        if not os.path.exists(
            join_path(self.workdir, "dictionary")
        ) or not os.path.exists(join_path(self.workdir, "nn")):
            snapshot_download(
                repo_id=repo,
                ignore_patterns=["*.md", "*.gitattributes"],
                local_dir=self.workdir,
                local_dir_use_symlinks=False,
            )

        with open(join_path(self.workdir, "dictionary/omographs.json"), encoding="utf-8") as f:
            self.omographs = json.load(f)
        with open(join_path(self.workdir, "dictionary/yo_words.json"), encoding="utf-8") as f:
            self.yo_words = json.load(f)

        self.dict_load_startup = dict_load_startup

        # Either keep the full stress dictionary in memory, or load per-letter
        # chunks lazily in process_accent().
        if dict_load_startup:
            with open(join_path(self.workdir, "dictionary/accents.json"), encoding="utf-8") as f:
                self.accents = json.load(f)

        if disable_accent_dict:
            self.accents = {}
            self.disable_accent_dict = True
        else:
            self.disable_accent_dict = False

        if omograph_model_size not in ["small", "medium"]:
            raise NotImplementedError(
                f"Unknown omograph model size: {omograph_model_size!r} (expected 'small' or 'medium')"
            )

        self.omograph_model.load(
            join_path(self.workdir, f"nn/nn_omograph/{omograph_model_size}/")
        )
        self.accent_model.load(join_path(self.workdir, "nn/nn_accent/"))
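
    # Judging by the paths used in load() and load_dict(), the downloaded
    # snapshot is expected to have roughly this layout (an inference from the
    # code, not a guarantee about the "TeraTTS/accentuator" repository):
    #
    #   <workdir>/
    #     dictionary/omographs.json           word -> list of stressed variants
    #     dictionary/yo_words.json            word -> spelling with "ё"
    #     dictionary/accents.json             full word -> stressed word map
    #     dictionary/letter_accent/<letter>.json  per-first-letter stress chunks
    #     nn/nn_omograph/small|medium/        homograph classifier weights
    #     nn/nn_accent/                       accent model weights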
    def split_by_words(self, string):
        # Lowercase and split into words (keeping "+" stress marks inside words)
        # and punctuation runs; drop empty matches produced by "\w*".
        result = re.findall(r"\w*(?:\+\w+)*|[^\w\s]+", string.lower())
        return [res for res in result if res]
    def extract_initial_letters(self, text):
        words = text
        initial_letters = []
        for word in words:
            # Only words longer than two letters that are not already stressed ("+")
            # and contain no Latin characters need a dictionary lookup.
            if len(word) > 2 and "+" not in word and not bool(re.search("[a-zA-Z]", word)):
                initial_letters.append(word[0])
        return initial_letters
    def load_dict(self, text):
        # Lazily merge only the per-letter stress dictionaries needed for this text.
        chars = self.extract_initial_letters(text)
        out_dict = {}
        for char in chars:
            with open(
                join_path(self.workdir, f"dictionary/letter_accent/{char}.json"),
                encoding="utf-8",
            ) as f:
                out_dict.update(json.load(f))
        return out_dict
    def count_vowels(self, text):
        vowels = "аеёиоуыэюяАЕЁИОУЫЭЮЯ"
        return sum(1 for char in text if char in vowels)
    def has_punctuation(self, text):
        for char in text:
            if char in "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~":
                return True
        return False
    def delete_spaces_before_punc(self, text):
        # "+" (stress mark) and "-" are not in this set, so they keep their spacing.
        punc = "!\"#$%&'()*,./:;<=>?@[\\]^_`{|}~"
        for char in punc:
            text = text.replace(" " + char, char)
        return text
    def process_yo(self, text):
        # Replace words whose "е" should actually be "ё", using the yo_words dictionary.
        splitted_text = text
        for i, word in enumerate(splitted_text):
            splitted_text[i] = self.yo_words.get(word, word)
        return splitted_text
    def process_omographs(self, text):
        splitted_text = text
        founded_omographs = []
        # Collect every word that has several possible stress variants.
        for i, word in enumerate(splitted_text):
            variants = self.omographs.get(word)
            if variants:
                founded_omographs.append(
                    {"word": word, "variants": variants, "position": i}
                )
        # Mark each ambiguous word with <w>...</w> and let the neural model
        # pick the right variant from the sentence context.
        for omograph in founded_omographs:
            splitted_text[omograph["position"]] = (
                f"<w>{splitted_text[omograph['position']]}</w>"
            )
            cls = self.omograph_model.classify(
                " ".join(splitted_text), omograph["variants"]
            )
            splitted_text[omograph["position"]] = cls
        return splitted_text
    def process_accent(self, text):
        # Without a preloaded dictionary, pull in only the per-letter chunks
        # relevant to this text.
        if not self.dict_load_startup and not self.disable_accent_dict:
            self.accents = self.load_dict(text)
        splitted_text = text
        for i, word in enumerate(splitted_text):
            stressed_word = self.accents.get(word, word)
            # Fall back to the neural model for words the dictionary does not
            # cover, skipping punctuation and single-vowel words.
            if stressed_word == word and not self.has_punctuation(word) and self.count_vowels(word) > 1:
                splitted_text[i] = self.accent_model.put_accent(word)
            else:
                splitted_text[i] = stressed_word
        return splitted_text
    def process_all(self, text):
        # Full pipeline: tokenize, restore "ё", resolve homographs, place
        # stress marks, then re-join and tidy punctuation spacing.
        text = self.split_by_words(text)
        processed_text = self.process_yo(text)
        processed_text = self.process_omographs(processed_text)
        processed_text = self.process_accent(processed_text)
        processed_text = " ".join(processed_text)
        processed_text = self.delete_spaces_before_punc(processed_text)
        return processed_text
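
# A minimal usage sketch (the top-level import name is an assumption about how
# this module is packaged; everything else mirrors the API defined above):
#
#     from ruaccent import RUAccent
#
#     accentizer = RUAccent()
#     accentizer.load(omograph_model_size="medium")  # or "small"
#     print(accentizer.process_all("на двери висит замок"))
#
# Note that split_by_words() lowercases the input, so the returned text is
# lowercase with stress marks inserted.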