import torch

from collections import deque

from jamotools import split_syllables, join_jamos
from transformers import PretrainedConfig, PreTrainedModel, AutoTokenizer


class HangulTokenizerConfig(PretrainedConfig):
    """Config holding the name of the base Hugging Face tokenizer to wrap."""

    model_type = "hangul_tokenizer"

    def __init__(
        self,
        base_tokenizer_name='unsloth/gemma-2-2b',
        **kwargs
    ):
        super().__init__(**kwargs)
        self.base_tokenizer_name = base_tokenizer_name


class HangulTokenizer(PreTrainedModel):
    config_class = HangulTokenizerConfig

    def __init__(self, config):
        super().__init__(config)
        # Dummy parameter so the wrapper has at least one weight and behaves as a PreTrainedModel.
        self.temp_module = torch.nn.Parameter(torch.ones(1))
        self.base_tokenizer = AutoTokenizer.from_pretrained(config.base_tokenizer_name)
        self.base_tokenizer.pad_token_id = 128
        self.base_tokenizer.pad_token = self.base_tokenizer.decode([self.base_tokenizer.pad_token_id])
        self.space_token_id = self.base_tokenizer.encode(' ', add_special_tokens=False)[-1]

        # Precomposed Hangul syllables occupy U+AC00..U+D7A3; a set gives O(1) membership checks.
        char_start, char_end = 0xAC00, 0xD7A3
        self.kor_chars = set(chr(code) for code in range(char_start, char_end + 1))

        # Group syllables by how many ids the base tokenizer assigns them:
        # 3-id syllables are kept as-is, single-id syllables get two pad ids to reach length 3.
        self.char_3ids = []
        self.char_1ids = []
        for kor_char in self.kor_chars:
            ids = self.base_tokenizer.encode(kor_char, add_special_tokens=False)
            if len(ids) == 3:
                self.char_3ids.append(ids)
            else:
                ids = ids + 2 * [self.base_tokenizer.pad_token_id]
                self.char_1ids.append(ids)

        self.chos = ['ㄱ', 'ㄲ', 'ㄴ', 'ㄷ', 'ㄸ', 'ㄹ', 'ㅁ', 'ㅂ', 'ㅃ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅉ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ']
        self.joongs = ['ㅏ', 'ㅐ', 'ㅑ', 'ㅒ', 'ㅓ', 'ㅔ', 'ㅕ', 'ㅖ', 'ㅗ', 'ㅘ', 'ㅙ', 'ㅚ', 'ㅛ', 'ㅜ', 'ㅝ', 'ㅞ', 'ㅟ', 'ㅠ', 'ㅡ', 'ㅢ', 'ㅣ']
        self.jongs = [self.base_tokenizer.pad_token, 'ㄱ', 'ㄲ', 'ㄳ', 'ㄴ', 'ㄵ', 'ㄶ', 'ㄷ', 'ㄹ', 'ㄺ', 'ㄻ', 'ㄼ', 'ㄽ', 'ㄾ', 'ㄿ', 'ㅀ', 'ㅁ', 'ㅂ', 'ㅄ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ']

        # Map each jamo (plus the pad token used for "no final consonant") to a single base-tokenizer id.
        jamos = list(set(self.chos) | set(self.joongs) | set(self.jongs))
        jamo_ids = self.base_tokenizer(jamos, add_special_tokens=False)['input_ids']
        self.jamo_to_id = {jamo: jamo_id[-1] for jamo, jamo_id in zip(jamos, jamo_ids)}
        self.cho_ids = [self.jamo_to_id[cho] for cho in self.chos]
        self.joong_ids = [self.jamo_to_id[joong] for joong in self.joongs]
        self.jong_ids = [self.jamo_to_id[jong] for jong in self.jongs]
        self.id_to_jamo = {jamo_id: jamo for jamo, jamo_id in self.jamo_to_id.items()}

    def encode_jamo(self, sentence):
        """Encode a sentence, expanding every Hangul syllable into (cho, joong, jong) jamo ids.

        Non-Korean spans are buffered and encoded with the base tokenizer (token_type_id 0);
        the three jamo positions get token_type_ids 1, 2 and 3 respectively.
        """
        encoded_ids = []
        token_type_ids = []
        past_chars = ''
        for char in sentence:
            if char in self.kor_chars:
                # Flush any buffered non-Korean text first.
                if past_chars:
                    past_chars_encoded = self.base_tokenizer.encode(past_chars, add_special_tokens=False)
                    encoded_ids.extend(past_chars_encoded)
                    token_type_ids.extend([0] * len(past_chars_encoded))
                    past_chars = ''
                # Decompose the syllable; use the pad token when there is no final consonant.
                char_splitted = list(split_syllables(char))[:3]
                char_splitted = char_splitted + (3 - len(char_splitted)) * [self.base_tokenizer.pad_token]
                cho, joong, jong = char_splitted
                encoded_ids.extend([self.jamo_to_id[cho], self.jamo_to_id[joong], self.jamo_to_id[jong]])
                token_type_ids.extend([1, 2, 3])
            else:
                past_chars = past_chars + char
        if past_chars:
            past_chars_encoded = self.base_tokenizer.encode(past_chars, add_special_tokens=False)
            encoded_ids.extend(past_chars_encoded)
            token_type_ids.extend([0] * len(past_chars_encoded))
        return encoded_ids, token_type_ids

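    # Illustrative sketch (not executed, assuming the Gemma-2 base tokenizer): encode_jamo("한a")
    # emits the three jamo ids for ('ㅎ', 'ㅏ', 'ㄴ') with token_type_ids [1, 2, 3], then flushes
    # the buffered "a" through the base tokenizer with token_type_id 0. decode_jamo below
    # reverses the mapping.
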
    def decode_jamo(self, encoded_ids, token_type_ids):
        """Inverse of encode_jamo: recompose jamo triples into syllables and decode the rest."""
        encoded_ids = deque(encoded_ids)
        token_type_ids = deque(token_type_ids)
        decoded = []
        past_ids = []
        while len(encoded_ids):
            encoded_id = encoded_ids.popleft()
            token_type_id = token_type_ids.popleft()
            if token_type_id == 0:
                past_ids.append(encoded_id)
            else:
                # A jamo triple starts here: flush buffered base-tokenizer ids, then
                # consume the next two ids (and their token types) as joong and jong.
                decoded.append(self.base_tokenizer.decode(past_ids))
                past_ids = []
                cho_id = encoded_id
                joong_id = encoded_ids.popleft()
                jong_id = encoded_ids.popleft()
                token_type_ids.popleft()
                token_type_ids.popleft()
                char = join_jamos([self.id_to_jamo[cho_id], self.id_to_jamo[joong_id], self.id_to_jamo[jong_id]])[:1]
                decoded.append(char)
        decoded.append(self.base_tokenizer.decode(past_ids))
        return ''.join(decoded)

    def encode_char(self, sentence):
        """Encode a sentence at syllable level: every Hangul syllable occupies exactly 3 id slots.

        Syllables that the base tokenizer encodes with fewer ids are right-padded with the pad id;
        their positions get token_type_id 4, non-Korean spans get token_type_id 0.
        """
        encoded_ids = []
        token_type_ids = []
        past_chars = ''
        for char in sentence:
            if char in self.kor_chars:
                # Flush any buffered non-Korean text first.
                if past_chars:
                    past_chars_encoded = self.base_tokenizer.encode(past_chars, add_special_tokens=False)
                    encoded_ids.extend(past_chars_encoded)
                    token_type_ids.extend([0] * len(past_chars_encoded))
                    past_chars = ''
                encoded_id = self.base_tokenizer.encode(char, add_special_tokens=False)
                encoded_id = encoded_id + (3 - len(encoded_id)) * [self.base_tokenizer.pad_token_id]
                encoded_ids.extend(encoded_id)
                token_type_ids.extend([4, 4, 4])
            else:
                past_chars = past_chars + char
        if past_chars:
            past_chars_encoded = self.base_tokenizer.encode(past_chars, add_special_tokens=False)
            encoded_ids.extend(past_chars_encoded)
            token_type_ids.extend([0] * len(past_chars_encoded))
        return encoded_ids, token_type_ids

    def decode_char(self, encoded_ids, token_type_ids):
        """Inverse of encode_char: decode each 3-id syllable block and the remaining spans."""
        encoded_ids = deque(encoded_ids)
        token_type_ids = deque(token_type_ids)
        decoded = []
        past_ids = []
        while len(encoded_ids):
            encoded_id = encoded_ids.popleft()
            token_type_id = token_type_ids.popleft()
            if token_type_id == 0:
                past_ids.append(encoded_id)
            else:
                # A syllable block starts here: flush buffered ids, then consume the
                # remaining two ids (and their token types) of the 3-id block.
                decoded.append(self.base_tokenizer.decode(past_ids))
                past_ids = []
                id1 = encoded_id
                id2 = encoded_ids.popleft()
                id3 = encoded_ids.popleft()
                token_type_ids.popleft()
                token_type_ids.popleft()
                # Keep only the syllable itself; trailing pad ids may decode to extra characters.
                char = self.base_tokenizer.decode([id1, id2, id3])[:1]
                decoded.append(char)
        decoded.append(self.base_tokenizer.decode(past_ids))
        return ''.join(decoded)

    def encode_jamo_from_char_encoded(self, encoded_ids, token_type_ids):
        """Convert char-level encodings (token_type_id 4) into jamo-level encodings (1, 2, 3)."""
        encoded_ids = deque(encoded_ids)
        token_type_ids = deque(token_type_ids)
        encoded_ids_new = []
        token_type_ids_new = []
        while len(encoded_ids):
            encoded_id = encoded_ids.popleft()
            token_type_id = token_type_ids.popleft()
            if token_type_id == 0:
                encoded_ids_new.append(encoded_id)
                token_type_ids_new.append(token_type_id)
            else:
                # Consume the full 3-id syllable block, decode it back to the syllable,
                # then re-encode it as a (cho, joong, jong) jamo triple.
                encoded_id2 = encoded_ids.popleft()
                encoded_id3 = encoded_ids.popleft()
                token_type_ids.popleft()
                token_type_ids.popleft()
                char = self.base_tokenizer.decode([encoded_id, encoded_id2, encoded_id3])[0]
                char_splitted = list(split_syllables(char))[:3]
                char_splitted = char_splitted + (3 - len(char_splitted)) * [self.base_tokenizer.pad_token]
                cho, joong, jong = char_splitted
                encoded_ids_new.extend([self.jamo_to_id[cho], self.jamo_to_id[joong], self.jamo_to_id[jong]])
                token_type_ids_new.extend([1, 2, 3])
        return encoded_ids_new, token_type_ids_new

    def batch_encode_char(self, sentences):
        """Char-level encode a batch, right-padding rows with the EOS id and masking the padding."""
        input_ids = []
        attention_mask = []
        token_type_ids = []
        for sentence in sentences:
            input_ids_row, token_type_id = self.encode_char(sentence)
            input_ids.append(input_ids_row)
            token_type_ids.append(token_type_id)
        max_length = max(map(len, input_ids))
        for i in range(len(sentences)):
            input_ids[i] = input_ids[i] + (max_length - len(input_ids[i])) * [self.base_tokenizer.eos_token_id]
            attention_mask.append([1 if input_id != self.base_tokenizer.eos_token_id else 0 for input_id in input_ids[i]])
            token_type_ids[i] = token_type_ids[i] + (max_length - len(token_type_ids[i])) * [0]
        return (
            torch.LongTensor(input_ids),
            torch.LongTensor(attention_mask),
            torch.LongTensor(token_type_ids),
        )

    def batch_encode_jamo_from_char_encoded(self, batch_encoded_ids, batch_token_type_ids):
        """Convert a batch of char-level encodings to jamo level.

        Each 3-id syllable block maps to a 3-id jamo block, so row lengths are preserved
        and no re-padding is needed; the EOS-based masking convention is reused.
        """
        input_ids = []
        attention_mask = []
        token_type_ids_new = []
        for encoded_ids, token_type_ids in zip(batch_encoded_ids, batch_token_type_ids):
            encoded_ids_row, token_type_ids_row = self.encode_jamo_from_char_encoded(encoded_ids, token_type_ids)
            attention_mask.append([1 if encoded_id != self.base_tokenizer.eos_token_id else 0 for encoded_id in encoded_ids_row])
            input_ids.append(encoded_ids_row)
            token_type_ids_new.append(token_type_ids_row)
        return (
            torch.LongTensor(input_ids),
            torch.LongTensor(attention_mask),
            torch.LongTensor(token_type_ids_new),
        )

    def batch_encode_jamo(self, sentences):
        """Jamo-level encode a batch, right-padding rows with the EOS id and masking the padding."""
        input_ids = []
        attention_mask = []
        token_type_ids = []
        for sentence in sentences:
            input_ids_row, token_type_id = self.encode_jamo(sentence)
            input_ids.append(input_ids_row)
            token_type_ids.append(token_type_id)
        max_length = max(map(len, input_ids))
        for i in range(len(sentences)):
            input_ids[i] = input_ids[i] + (max_length - len(input_ids[i])) * [self.base_tokenizer.eos_token_id]
            attention_mask.append([1 if input_id != self.base_tokenizer.eos_token_id else 0 for input_id in input_ids[i]])
            token_type_ids[i] = token_type_ids[i] + (max_length - len(token_type_ids[i])) * [0]
        return (
            torch.LongTensor(input_ids),
            torch.LongTensor(attention_mask),
            torch.LongTensor(token_type_ids),
        )
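

# Minimal usage sketch (assumes the 'unsloth/gemma-2-2b' tokenizer can be downloaded and that
# jamotools is installed): round-trips one Korean/English sentence through the jamo- and
# char-level codecs and shows the shapes returned by batched encoding.
if __name__ == "__main__":
    config = HangulTokenizerConfig(base_tokenizer_name='unsloth/gemma-2-2b')
    tokenizer = HangulTokenizer(config)

    sample = "안녕하세요 world"

    # Jamo-level round trip: each syllable expands to 3 jamo ids (token types 1, 2, 3).
    jamo_ids, jamo_types = tokenizer.encode_jamo(sample)
    print(tokenizer.decode_jamo(jamo_ids, jamo_types))

    # Char-level round trip: each syllable occupies a fixed 3-id slot (token type 4).
    char_ids, char_types = tokenizer.encode_char(sample)
    print(tokenizer.decode_char(char_ids, char_types))

    # Batched jamo encoding returns padded LongTensors plus an attention mask.
    input_ids, attention_mask, token_type_ids = tokenizer.batch_encode_jamo([sample, "한국어"])
    print(input_ids.shape, attention_mask.shape, token_type_ids.shape)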