|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Credits |
|
This code is modified from https://github.com/GitYCC/g2pW |
|
""" |
|
import os |
|
import re |
|
|
|
|
|
def wordize_and_map(text: str):
    """Split ``text`` into words and build character <-> word index maps.

    A word is either a maximal run of ASCII letters/digits (``[a-zA-Z0-9]+``)
    or any other single character (e.g. one CJK character, a punctuation mark,
    a tab or a newline). Runs of the ASCII space character ``' '`` only
    separate words and are not words themselves.

    Args:
        text: The input string.

    Returns:
        A tuple ``(words, index_map_from_text_to_word,
        index_map_from_word_to_text)`` where

        - ``words`` is the list of word strings in order;
        - ``index_map_from_text_to_word`` has one entry per character of
          ``text``: the index into ``words`` of the word containing it, or
          ``None`` for a space;
        - ``index_map_from_word_to_text`` holds one half-open ``(start, end)``
          character span into ``text`` per word.
    """
    words = []
    index_map_from_text_to_word = []
    index_map_from_word_to_text = []

    # A single left-to-right regex scan replaces repeatedly slicing ``text``
    # (each slice copied the remainder, making the loop O(n^2)).
    # Alternation order matters: space runs first (group 1), then ASCII
    # alphanumeric runs, then any one other character. DOTALL lets ``.``
    # also match newlines, so every character is consumed exactly once.
    for match in re.finditer(r'( +)|[a-zA-Z0-9]+|.', text, flags=re.DOTALL):
        start, end = match.span()
        if match.group(1) is not None:
            # Spaces belong to no word.
            index_map_from_text_to_word.extend([None] * (end - start))
            continue
        index_map_from_word_to_text.append((start, end))
        index_map_from_text_to_word.extend([len(words)] * (end - start))
        words.append(match.group(0))

    return words, index_map_from_text_to_word, index_map_from_word_to_text
|
|
|
|
|
def tokenize_and_map(tokenizer, text: str):
    """Tokenize ``text`` word by word and map characters <-> tokens.

    ``text`` is first split with ``wordize_and_map``; each word is then passed
    to ``tokenizer.tokenize``. A word that yields no pieces, or exactly
    ``['[UNK]']``, becomes a single ``'[UNK]'`` token covering the whole word.
    Otherwise every piece gets a consecutive character span whose length is
    the piece length with one leading ``'##'`` removed.

    NOTE(review): presumably ``tokenizer`` is a BERT-style WordPiece
    tokenizer; only ``tokenize(str) -> list[str]`` and the ``'##'``
    continuation prefix are relied on here. If a word's piece lengths add up
    to more than the word's length, the spans spill into following
    characters (or past the end of ``text``, raising IndexError).

    Args:
        tokenizer: Object with a ``tokenize(word)`` method returning a list
            of token strings.
        text: The input string.

    Returns:
        A tuple ``(tokens, index_map_from_text_to_token,
        index_map_from_token_to_text)``: the token list; one entry per
        character of ``text`` holding its token index (``None`` for spaces);
        and one half-open ``(start, end)`` character span per token.
    """
    words, text2word, word2text = wordize_and_map(text=text)

    out_tokens = []
    spans = []
    for word, (span_start, span_end) in zip(words, word2text):
        pieces = tokenizer.tokenize(word)

        if len(pieces) == 0 or pieces == ['[UNK]']:
            # The whole word collapses into one unknown token.
            spans.append((span_start, span_end))
            out_tokens.append('[UNK]')
            continue

        cursor = span_start
        for piece in pieces:
            # '##' marks a WordPiece continuation and covers no characters.
            piece_len = len(re.sub(r'^##', '', piece))
            spans.append((cursor, cursor + piece_len))
            cursor += piece_len
            out_tokens.append(piece)

    # Reuse the per-character word map and overwrite each covered character
    # with its token index; spaces keep their ``None`` entry.
    char_to_token = text2word
    for token_idx, (span_start, span_end) in enumerate(spans):
        for char_pos in range(span_start, span_end):
            char_to_token[char_pos] = token_idx

    return out_tokens, char_to_token, spans
|
|
|
|
|
def _load_config(config_path: os.PathLike):
    """Execute a Python config file and return it as a module object.

    Top-level names assigned in the file become attributes of the returned
    module. Every call builds and executes a new module object.

    Note:
        The file is executed as Python code; only load trusted config files.

    Args:
        config_path: Path to a Python config file (e.g. ``config.py``).

    Returns:
        The executed config module.

    Raises:
        ImportError: If no import loader can be determined for
            ``config_path`` (e.g. the file lacks a ``.py`` suffix).
    """
    import importlib.util

    spec = importlib.util.spec_from_file_location('__init__', config_path)
    # spec_from_file_location returns None for unrecognised file suffixes;
    # without this check the caller sees an opaque AttributeError on
    # ``spec.loader`` instead of a message naming the offending file.
    if spec is None or spec.loader is None:
        raise ImportError(
            f'Cannot load config file {config_path!r}: '
            'no import loader for this file type (expected a .py file)')
    config = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config)
    return config
|
|
|
|
|
# Fallback settings that ``load_config(..., use_default=True)`` merges into a
# loaded config: missing top-level keys are added, and for dict-valued
# entries that the config already defines, missing sub-keys are added.
# Keys must exactly match the attribute names used in config files.
default_config_dict = {
    'manual_seed': 1313,
    'model_source': 'bert-base-chinese',
    'window_size': 32,
    'num_workers': 2,
    'use_mask': True,
    'use_char_phoneme': False,
    'use_conditional': True,
    'param_conditional': {
        'affect_location': 'softmax',
        'bias': True,
        'char-linear': True,
        'pos-linear': False,
        'char+pos-second': True,
        'char+pos-second_lowrank': False,
        'lowrank_size': 0,
        'char+pos-second_fm': False,
        'fm_size': 0,
        'fix_mode': None,
        'count_json': 'train.count.json',
    },
    'lr': 5e-5,
    'val_interval': 200,
    'num_iter': 10000,
    'use_focal': False,
    'param_focal': {
        'alpha': 0.0,
        'gamma': 0.7,
    },
    'use_pos': True,
    # Previously spelled 'param_pos ' (trailing space), so these defaults
    # never merged into ``config.param_pos``.
    'param_pos': {
        'weight': 0.1,
        'pos_joint_training': True,
        'train_pos_path': 'train.pos',
        'valid_pos_path': 'dev.pos',
        'test_pos_path': 'test.pos',
    },
}
|
|
|
|
|
def load_config(config_path: os.PathLike, use_default: bool=False):
    """Load a Python config file, optionally filling in missing defaults.

    Args:
        config_path: Path to the Python config file to execute.
        use_default: When True, merge ``default_config_dict`` into the config:
            each top-level key missing from the config is added as an
            attribute (as an independent deep copy). For dict-valued defaults
            whose attribute already exists, missing sub-keys are inserted
            into the config's own dict in place.

    Returns:
        The loaded config module.
    """
    import copy

    config = _load_config(config_path)
    if use_default:
        for attr, val in default_config_dict.items():
            if not hasattr(config, attr):
                # Deep-copy so mutating this config's settings cannot leak
                # into the module-level defaults and, through them, into
                # every config loaded afterwards.
                setattr(config, attr, copy.deepcopy(val))
            elif isinstance(val, dict):
                d = getattr(config, attr)
                for dict_k, dict_v in val.items():
                    if dict_k not in d:
                        d[dict_k] = copy.deepcopy(dict_v)
    return config
|
|