# gemma2-2b-kor-deobfuscation / modeling_hangul_tokenizer.py

import torch
from collections import deque
from jamotools import split_syllables, join_jamos
from transformers import PretrainedConfig, PreTrainedModel, AutoTokenizer


class HangulTokenizerConfig(PretrainedConfig):
    """Configuration holding the name of the base tokenizer to wrap."""

    model_type = "hangul_tokenizer"

    def __init__(
        self,
        base_tokenizer_name='unsloth/gemma-2-2b',
        **kwargs
    ):
        super().__init__(**kwargs)
        self.base_tokenizer_name = base_tokenizer_name


class HangulTokenizer(PreTrainedModel):
    """Wraps a base tokenizer with Hangul-aware char-level and jamo-level encodings.

    Token-type convention used throughout:
        0     = ordinary text encoded by the base tokenizer,
        1/2/3 = cho/joong/jong (initial/medial/final jamo) of one syllable,
        4     = one slot of a 3-slot char-level syllable encoding.
    """

    config_class = HangulTokenizerConfig

    def __init__(self, config):
        super().__init__(config)
        # Placeholder parameter; unused by the tokenizer logic but gives the
        # PreTrainedModel at least one weight to manage.
        self.temp_module = torch.nn.Parameter(torch.ones(1))
        self.base_tokenizer = AutoTokenizer.from_pretrained(config.base_tokenizer_name)
        # Token id 128 is repurposed as the filler/pad token for the 3-slot layouts below.
        self.base_tokenizer.pad_token_id = 128
        self.base_tokenizer.pad_token = self.base_tokenizer.decode([self.base_tokenizer.pad_token_id])
        self.space_token_id = self.base_tokenizer.encode(' ', add_special_tokens=False)[-1]
        # All precomposed Hangul syllables (κ°€-힣).
        char_start, char_end = 0xAC00, 0xD7A3
        self.kor_chars = list(set([chr(code) for code in range(char_start, char_end + 1)]))
        # Group syllables by how many tokens the base tokenizer needs for them:
        # 3-token syllables vs. single-token syllables padded out to 3 ids.
        self.char_3ids = []
        self.char_1ids = []
        for kor_char in self.kor_chars:
            ids = self.base_tokenizer.encode(kor_char, add_special_tokens=False)
            if len(ids) == 3:
                self.char_3ids.append(ids)
            else:
                ids = ids + 2 * [self.base_tokenizer.pad_token_id]
                self.char_1ids.append(ids)
        # Jamo inventories: 19 initials, 21 medials, 28 finals
        # (index 0 of the finals is the pad token, meaning "no final consonant").
        self.chos = ['γ„±', 'γ„²', 'γ„΄', 'γ„·', 'γ„Έ', 'γ„Ή', 'ㅁ', 'γ…‚', 'γ…ƒ', 'γ……', 'γ…†', 'γ…‡', 'γ…ˆ', 'γ…‰', 'γ…Š', 'γ…‹', 'γ…Œ', 'ㅍ', 'γ…Ž']
        self.joongs = ['ㅏ', 'ㅐ', 'γ…‘', 'γ…’', 'γ…“', 'γ…”', 'γ…•', 'γ…–', 'γ…—', 'γ…˜', 'γ…™', 'γ…š', 'γ…›', 'γ…œ', 'ㅝ', 'γ…ž', 'γ…Ÿ', 'γ… ', 'γ…‘', 'γ…’', 'γ…£']
        self.jongs = [self.base_tokenizer.pad_token, 'γ„±', 'γ„²', 'γ„³', 'γ„΄', 'γ„΅', 'γ„Ά', 'γ„·', 'γ„Ή', 'γ„Ί', 'γ„»', 'γ„Ό', 'γ„½', 'γ„Ύ', 'γ„Ώ', 'γ…€', 'ㅁ', 'γ…‚', 'γ…„', 'γ……', 'γ…†', 'γ…‡', 'γ…ˆ', 'γ…Š', 'γ…‹', 'γ…Œ', 'ㅍ', 'γ…Ž']
        jamos = list(set(self.chos) | set(self.joongs) | set(self.jongs))
        jamo_ids = self.base_tokenizer(jamos, add_special_tokens=False)['input_ids']
        # Keep only the last id of each jamo's tokenization.
        self.jamo_to_id = {jamo: jamo_id[-1] for jamo, jamo_id in zip(jamos, jamo_ids)}
        self.cho_ids = [self.jamo_to_id[cho] for cho in self.chos]
        self.joong_ids = [self.jamo_to_id[joong] for joong in self.joongs]
        self.jong_ids = [self.jamo_to_id[jong] for jong in self.jongs]
        self.id_to_jamo = {jamo_id: jamo for jamo, jamo_id in self.jamo_to_id.items()}
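
    # Illustration (added; not in the original source). Both encodings use a fixed
    # 3-slot layout per Hangul syllable; ids below are symbolic:
    #   char level: a syllable the base tokenizer splits into 3 tokens keeps its
    #       3 ids; a single-token syllable becomes [id, pad, pad].
    #   jamo level: every syllable becomes [cho_id, joong_id, jong_id], with the
    #       pad jamo standing in when there is no final consonant (e.g. 'κ°€' -> γ„±, ㅏ, pad).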

    def encode_jamo(self, sentence):
        """Encode a sentence, expanding each Hangul syllable into three jamo ids (cho/joong/jong)."""
        encoded_ids = []
        token_type_ids = []
        past_chars = ''
        for char in sentence:
            if char in self.kor_chars:
                # Flush any buffered non-Hangul text through the base tokenizer.
                if past_chars:
                    past_chars_encoded = self.base_tokenizer.encode(past_chars, add_special_tokens=False)
                    encoded_ids.extend(past_chars_encoded)
                    token_type_ids.extend([0] * len(past_chars_encoded))
                    past_chars = ''
                # Split the syllable into (cho, joong, jong), padding a missing final.
                char_splitted = list(split_syllables(char))[:3]
                char_splitted = char_splitted + (3 - len(char_splitted)) * [self.base_tokenizer.pad_token]
                cho, joong, jong = char_splitted
                encoded_ids.extend([self.jamo_to_id[cho], self.jamo_to_id[joong], self.jamo_to_id[jong]])
                token_type_ids.extend([1, 2, 3])
            else:
                past_chars = past_chars + char
        if past_chars:
            past_chars_encoded = self.base_tokenizer.encode(past_chars, add_special_tokens=False)
            encoded_ids.extend(past_chars_encoded)
            token_type_ids.extend([0] * len(past_chars_encoded))
        return encoded_ids, token_type_ids
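
    # Example of the layout produced by encode_jamo (added for illustration;
    # ids are symbolic, not actual vocabulary ids):
    #   encode_jamo('ν•œ A') ->
    #       encoded_ids    = [id(γ…Ž), id(ㅏ), id(γ„΄), *base_ids(' A')]
    #       token_type_ids = [1,      2,      3,      0, ...]
    # Non-Hangul spans are buffered and encoded with the base tokenizer as type 0.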

    def decode_jamo(self, encoded_ids, token_type_ids):
        """Decode a jamo-level encoding back to text, recomposing syllables from jamo triples."""
        encoded_ids = deque(encoded_ids)
        token_type_ids = deque(token_type_ids)
        decoded = []
        past_ids = []
        while len(encoded_ids):
            encoded_id = encoded_ids.popleft()
            token_type_id = token_type_ids.popleft()
            if token_type_id == 0:
                past_ids.append(encoded_id)
            else:
                # Flush buffered base-tokenizer ids, then consume one jamo triple.
                decoded.append(self.base_tokenizer.decode(past_ids))
                past_ids = []
                cho_id = encoded_id
                joong_id = encoded_ids.popleft()
                jong_id = encoded_ids.popleft()
                token_type_ids.popleft()
                token_type_ids.popleft()
                char = join_jamos([self.id_to_jamo[cho_id], self.id_to_jamo[joong_id], self.id_to_jamo[jong_id]])[:1]
                decoded.append(char)
        decoded.append(self.base_tokenizer.decode(past_ids))
        return ''.join(decoded)

    def encode_char(self, sentence):
        """Encode a sentence, giving each Hangul syllable a fixed 3-slot char-level encoding."""
        encoded_ids = []
        token_type_ids = []
        past_chars = ''
        for char in sentence:
            if char in self.kor_chars:
                # Flush any buffered non-Hangul text through the base tokenizer.
                if past_chars:
                    past_chars_encoded = self.base_tokenizer.encode(past_chars, add_special_tokens=False)
                    encoded_ids.extend(past_chars_encoded)
                    token_type_ids.extend([0] * len(past_chars_encoded))
                    past_chars = ''
                # Pad single-token syllables out to three ids.
                encoded_id = self.base_tokenizer.encode(char, add_special_tokens=False)
                encoded_id = encoded_id + (3 - len(encoded_id)) * [self.base_tokenizer.pad_token_id]
                encoded_ids.extend(encoded_id)
                token_type_ids.extend([4, 4, 4])
            else:
                past_chars = past_chars + char
        if past_chars:
            past_chars_encoded = self.base_tokenizer.encode(past_chars, add_special_tokens=False)
            encoded_ids.extend(past_chars_encoded)
            token_type_ids.extend([0] * len(past_chars_encoded))
        return encoded_ids, token_type_ids

    def decode_char(self, encoded_ids, token_type_ids):
        """Decode a char-level encoding back to text, taking one syllable per 3-id group."""
        encoded_ids = deque(encoded_ids)
        token_type_ids = deque(token_type_ids)
        decoded = []
        past_ids = []
        while len(encoded_ids):
            encoded_id = encoded_ids.popleft()
            token_type_id = token_type_ids.popleft()
            if token_type_id == 0:
                past_ids.append(encoded_id)
            else:
                # Flush buffered base-tokenizer ids, then consume one 3-id syllable group.
                decoded.append(self.base_tokenizer.decode(past_ids))
                past_ids = []
                id1 = encoded_id
                id2 = encoded_ids.popleft()
                id3 = encoded_ids.popleft()
                token_type_ids.popleft()
                token_type_ids.popleft()
                char = self.base_tokenizer.decode([id1, id2, id3])[:1]
                decoded.append(char)
        decoded.append(self.base_tokenizer.decode(past_ids))
        return ''.join(decoded)

    def encode_jamo_from_char_encoded(self, encoded_ids, token_type_ids):
        """Rewrite a char-level encoding as a jamo-level one, leaving type-0 ids untouched."""
        encoded_ids = deque(encoded_ids)
        token_type_ids = deque(token_type_ids)
        encoded_ids_new = []
        token_type_ids_new = []
        while len(encoded_ids):
            encoded_id = encoded_ids.popleft()
            token_type_id = token_type_ids.popleft()
            if token_type_id == 0:
                encoded_ids_new.append(encoded_id)
                token_type_ids_new.append(token_type_id)
            else:
                # Recover the syllable from its 3-id group, then re-encode it as jamos.
                encoded_id2 = encoded_ids.popleft()
                encoded_id3 = encoded_ids.popleft()
                token_type_ids.popleft()
                token_type_ids.popleft()
                char = self.base_tokenizer.decode([encoded_id, encoded_id2, encoded_id3])[0]
                char_splitted = list(split_syllables(char))[:3]
                char_splitted = char_splitted + (3 - len(char_splitted)) * [self.base_tokenizer.pad_token]
                cho, joong, jong = char_splitted
                encoded_ids_new.extend([self.jamo_to_id[cho], self.jamo_to_id[joong], self.jamo_to_id[jong]])
                token_type_ids_new.extend([1, 2, 3])
        return encoded_ids_new, token_type_ids_new
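
    # Note (added): encode_jamo_from_char_encoded is length-preserving, since each
    # type-4 triple is rewritten as a type-1/2/3 triple and type-0 ids pass through
    # unchanged; that is why batch_encode_jamo_from_char_encoded below can stack
    # rows without re-padding.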

    def batch_encode_char(self, sentences):
        """Char-level encode a batch, padding rows to the longest length with the EOS id."""
        input_ids = []
        attention_mask = []
        token_type_ids = []
        for sentence in sentences:
            input_ids_row, token_type_id = self.encode_char(sentence)
            input_ids.append(input_ids_row)
            token_type_ids.append(token_type_id)
        max_length = max(list(map(len, input_ids)))
        for i in range(len(sentences)):
            input_ids[i] = input_ids[i] + (max_length - len(input_ids[i])) * [self.base_tokenizer.eos_token_id]
            attention_mask.append([1 if input_id != self.base_tokenizer.eos_token_id else 0 for input_id in input_ids[i]])
            token_type_ids[i] = token_type_ids[i] + (max_length - len(token_type_ids[i])) * [0]
        return (
            torch.LongTensor(input_ids),
            torch.LongTensor(attention_mask),
            torch.LongTensor(token_type_ids)
        )

    def batch_encode_jamo_from_char_encoded(self, batch_encoded_ids, batch_token_type_ids):
        """Convert an already padded char-level batch to its jamo-level form row by row."""
        input_ids = []
        attention_mask = []
        token_type_ids_new = []
        for encoded_ids, token_type_ids in zip(batch_encoded_ids, batch_token_type_ids):
            encoded_ids_row, token_type_ids_row = self.encode_jamo_from_char_encoded(encoded_ids, token_type_ids)
            attention_mask.append([1 if encoded_id != self.base_tokenizer.eos_token_id else 0 for encoded_id in encoded_ids_row])
            input_ids.append(encoded_ids_row)
            token_type_ids_new.append(token_type_ids_row)
        return (
            torch.LongTensor(input_ids),
            torch.LongTensor(attention_mask),
            torch.LongTensor(token_type_ids_new),
        )

    def batch_encode_jamo(self, sentences):
        """Jamo-level encode a batch, padding rows to the longest length with the EOS id."""
        input_ids = []
        attention_mask = []
        token_type_ids = []
        for sentence in sentences:
            input_ids_row, token_type_id = self.encode_jamo(sentence)
            input_ids.append(input_ids_row)
            token_type_ids.append(token_type_id)
        max_length = max(list(map(len, input_ids)))
        for i in range(len(sentences)):
            input_ids[i] = input_ids[i] + (max_length - len(input_ids[i])) * [self.base_tokenizer.eos_token_id]
            attention_mask.append([1 if input_id != self.base_tokenizer.eos_token_id else 0 for input_id in input_ids[i]])
            token_type_ids[i] = token_type_ids[i] + (max_length - len(token_type_ids[i])) * [0]
        return (
            torch.LongTensor(input_ids),
            torch.LongTensor(attention_mask),
            torch.LongTensor(token_type_ids),
        )
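

if __name__ == "__main__":
    # Minimal usage sketch (added; not part of the original module). It assumes
    # the 'unsloth/gemma-2-2b' tokenizer files are reachable from this machine.
    config = HangulTokenizerConfig()
    tokenizer = HangulTokenizer(config)

    # Char-level encoding: every Hangul syllable takes three id slots (type 4).
    sentences = ['μ•ˆλ…•ν•˜μ„Έμš”, world!']
    input_ids, attention_mask, token_type_ids = tokenizer.batch_encode_char(sentences)

    # Re-express the same batch at jamo level (types 1/2/3 for cho/joong/jong).
    jamo_ids, jamo_mask, jamo_types = tokenizer.batch_encode_jamo_from_char_encoded(
        input_ids.tolist(), token_type_ids.tolist()
    )

    # Round-trip the first row back to text.
    print(tokenizer.decode_jamo(jamo_ids[0].tolist(), jamo_types[0].tolist()))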