File size: 5,348 Bytes
dfc143a
 
 
 
 
 
 
 
 
 
 
328b51c
 
 
dfc143a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import json
import pathlib
from huggingface_hub import snapshot_download
import os
from os.path import join as join_path
from .omograph_model import OmographModel
from .accent_model import AccentModel
import re


class RUAccent:
    """Adds stress marks ('+' before the stressed vowel) and restores the
    letter 'ё' in Russian text.

    The pipeline (see ``process_all``) runs in stages: tokenization,
    'ё'-restoration from a dictionary, omograph disambiguation with a
    neural classifier, and per-word stress placement via dictionary
    lookup with a neural-model fallback.
    """

    def __init__(self, workdir=None, allow_cuda=True):
        """Create the accentuator and its two neural model wrappers.

        Args:
            workdir: Directory holding the ``dictionary``/``nn`` resources.
                Defaults to the package directory, so downloaded files are
                cached next to the code.
            allow_cuda: Passed through to both neural models.
        """
        self.omograph_model = OmographModel(allow_cuda=allow_cuda)
        self.accent_model = AccentModel(allow_cuda=allow_cuda)
        if workdir:
            self.workdir = workdir
        else:
            self.workdir = str(pathlib.Path(__file__).resolve().parent)

    def _read_json(self, relative_path):
        """Load a UTF-8 JSON resource below ``workdir``, closing the file.

        Replaces the previous ``json.load(open(...))`` pattern, which
        leaked file handles.
        """
        with open(join_path(self.workdir, relative_path), encoding="utf-8") as file:
            return json.load(file)

    def load(
        self,
        omograph_model_size="medium",
        dict_load_startup=False,
        disable_accent_dict=False,
        repo="TeraTTS/accentuator",
    ):
        """Download (if missing) and load all dictionaries and models.

        Args:
            omograph_model_size: Omograph classifier size, "small" or "medium".
            dict_load_startup: Load the full accent dictionary now instead of
                lazily, per first letter, inside ``process_accent``.
            disable_accent_dict: Skip the accent dictionary entirely and rely
                on the neural accent model alone.
            repo: Hugging Face repository to fetch resources from.

        Raises:
            NotImplementedError: If ``omograph_model_size`` is unsupported.
        """
        # Fail fast on an unsupported size before any network or file IO.
        if omograph_model_size not in ("small", "medium"):
            raise NotImplementedError(
                f"omograph_model_size must be 'small' or 'medium', "
                f"got {omograph_model_size!r}"
            )

        # Download once; both resource directories must exist to skip it.
        if not os.path.exists(
            join_path(self.workdir, "dictionary")
        ) or not os.path.exists(join_path(self.workdir, "nn")):
            snapshot_download(
                repo_id=repo,
                ignore_patterns=["*.md", "*.gitattributes"],
                local_dir=self.workdir,
                local_dir_use_symlinks=False,
            )

        self.omographs = self._read_json("dictionary/omographs.json")
        self.yo_words = self._read_json("dictionary/yo_words.json")
        self.dict_load_startup = dict_load_startup
        self.disable_accent_dict = bool(disable_accent_dict)

        # Check ``disable_accent_dict`` first: the original loaded the big
        # accents file and then discarded it when the dict was disabled.
        if self.disable_accent_dict:
            self.accents = {}
        elif dict_load_startup:
            self.accents = self._read_json("dictionary/accents.json")

        self.omograph_model.load(
            join_path(self.workdir, f"nn/nn_omograph/{omograph_model_size}/")
        )
        self.accent_model.load(join_path(self.workdir, "nn/nn_accent/"))

    def split_by_words(self, string):
        """Lower-case ``string`` and split it into word/punctuation tokens.

        Words may carry internal '+' stress marks ("сло+во"); runs of
        punctuation become separate tokens. Empty regex matches are dropped.
        """
        tokens = re.findall(r"\w*(?:\+\w+)*|[^\w\s]+", string.lower())
        return [token for token in tokens if token]

    def extract_initial_letters(self, text):
        """Return first letters of tokens that may need a dictionary lookup.

        Skips short tokens, tokens already carrying a '+' stress mark, and
        tokens containing Latin letters.

        Args:
            text: Iterable of word tokens (not a raw string).
        """
        initial_letters = []
        for word in text:
            if len(word) > 2 and "+" not in word and not re.search("[a-zA-Z]", word):
                initial_letters.append(word[0])
        return initial_letters

    def load_dict(self, text):
        """Build an accent dictionary covering only letters used in ``text``.

        Loading per-letter JSON shards keeps startup fast when
        ``dict_load_startup`` is off.
        """
        out_dict = {}
        for char in self.extract_initial_letters(text):
            out_dict.update(
                self._read_json(f"dictionary/letter_accent/{char}.json")
            )
        return out_dict

    def count_vowels(self, text):
        """Count Russian vowels in ``text`` (both cases)."""
        vowels = "аеёиоуыэюяАЕЁИОУЫЭЮЯ"
        return sum(1 for char in text if char in vowels)

    def has_punctuation(self, text):
        """Return True if ``text`` contains any ASCII punctuation character."""
        return any(char in "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" for char in text)

    def delete_spaces_before_punc(self, text):
        """Re-attach punctuation to the preceding word after a space-join.

        NOTE(review): '+' and '-' are absent from this set — presumably
        because '+' is the stress marker and '-' joins hyphenated words.
        """
        punc = "!\"#$%&'()*,./:;<=>?@[\\]^_`{|}~"
        for char in punc:
            text = text.replace(" " + char, char)
        return text

    def process_yo(self, text):
        """Replace tokens with their 'ё' spelling where the dictionary has one.

        Mutates and returns the token list.
        """
        for i, word in enumerate(text):
            text[i] = self.yo_words.get(word, word)
        return text

    def process_omographs(self, text):
        """Resolve stress for omographs (words with several stress variants).

        Each omograph is wrapped in <w>...</w> so the classifier can locate
        it in the sentence, then replaced by the chosen stressed variant.
        Mutates and returns the token list.
        """
        found_omographs = []
        for i, word in enumerate(text):
            variants = self.omographs.get(word)
            if variants:
                found_omographs.append(
                    {"word": word, "variants": variants, "position": i}
                )
        for omograph in found_omographs:
            position = omograph["position"]
            # Mark the target token so the classifier knows which to resolve.
            text[position] = f"<w>{text[position]}</w>"
            text[position] = self.omograph_model.classify(
                " ".join(text), omograph["variants"]
            )
        return text

    def process_accent(self, text):
        """Place stress marks on every word token.

        Dictionary words get their recorded stress; unknown multi-vowel,
        punctuation-free words fall back to the neural accent model; all
        other tokens pass through unchanged. Mutates and returns the list.
        """
        # Lazily load only the dictionary shards needed for this text.
        if not self.dict_load_startup and not self.disable_accent_dict:
            self.accents = self.load_dict(text)

        for i, word in enumerate(text):
            stressed_word = self.accents.get(word, word)
            if (
                stressed_word == word
                and not self.has_punctuation(word)
                and self.count_vowels(word) > 1
            ):
                text[i] = self.accent_model.put_accent(word)
            else:
                text[i] = stressed_word
        return text

    def process_all(self, text):
        """Full pipeline: tokenize, restore 'ё', resolve omographs, add stress.

        Returns the processed text as a single string with punctuation
        re-attached to the preceding words.
        """
        tokens = self.split_by_words(text)
        tokens = self.process_yo(tokens)
        tokens = self.process_omographs(tokens)
        tokens = self.process_accent(tokens)
        return self.delete_spaces_before_punc(" ".join(tokens))