# mrasp / tokenization_bat.py · thehonestbob · Upload 10 files · 259a0a5
# -*- coding: utf-8 -*-
"""
@author:cb
@contact:[email protected]
@time:2023/5/30 14:21
@filename:tokenization.py
@software:PyCharm
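@description: FSMTTokenizer subclass that prepends a language-tag token (en/zh) to every encoded sequence, overrides tokenization so Chinese punctuation is not converted to English, and strips whitespace before non-ASCII-alphanumeric characters when decoding.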
"""
import re
from transformers import FSMTTokenizer as fsmt
class FSMTTokenizer(fsmt):
    """FSMT tokenizer variant that tags every sequence with a language token.

    Compared with the parent class, this subclass:

    * prepends ``lang_prefix_id`` to each encoded sequence
      (see ``build_inputs_with_special_tokens``);
    * tokenizes with Moses punctuation normalisation + Moses tokenization
      (``escape=False``) followed by BPE (see ``_tokenize``);
    * strips whitespace in front of characters that are not ASCII letters,
      digits or spaces when decoding (see ``convert_tokens_to_string``).

    The translation direction is controlled by the ``reversal`` attribute:
    ``False`` (default) means English input / Chinese target, ``True`` swaps
    the two.
    """

    # (language code, vocabulary id of the language-tag token).
    # NOTE(review): the ids are hard-coded and presumably match the vocabulary
    # shipped with the checkpoint -- confirm before reusing with another vocab.
    _EN_PREFIX = ('en', 64812)
    _ZH_PREFIX = ('zh', 64870)

    def __init__(self, *args, **kwargs):
        """Initialise the parent tokenizer and this subclass's extra state.

        All positional and keyword arguments are forwarded unchanged to the
        parent constructor.
        """
        super().__init__(*args, **kwargs)
        # Raw string: '\s' in a plain literal is an invalid escape sequence.
        self.space_re = re.compile(r'\s*(?=[^a-zA-Z0-9 ]+)\s*')
        self.reversal = False
        # Make sure ``lang_prefix`` / ``lang_prefix_id`` exist even when
        # ``_tokenize`` or ``build_inputs_with_special_tokens`` is reached
        # before any mode-switch hook has run.
        self._switch_to_input_mode()

    def moses_tokenize(self, text, lang):
        """Tokenize ``text`` with a per-language cached Moses tokenizer.

        Args:
            text: The string to tokenize.
            lang: Language code used to select / create the Moses tokenizer.

        Returns:
            The list of tokens produced by the Moses tokenizer
            (``return_str=False``), with aggressive dash splitting enabled
            and escaping disabled.
        """
        if lang not in self.cache_moses_tokenizer:
            self.cache_moses_tokenizer[lang] = self.sm.MosesTokenizer(lang=lang)
        tokenizer = self.cache_moses_tokenizer[lang]
        return tokenizer.tokenize(
            text, aggressive_dash_splits=True, return_str=False, escape=False
        )

    def _switch_to_input_mode(self):
        """Select the language prefix used for source-side text.

        English unless ``reversal`` is set, in which case Chinese.
        """
        prefix = self._ZH_PREFIX if self.reversal else self._EN_PREFIX
        self.lang_prefix, self.lang_prefix_id = prefix

    def _switch_to_target_mode(self):
        """Select the language prefix used for target-side text.

        Chinese unless ``reversal`` is set, in which case English.
        """
        prefix = self._EN_PREFIX if self.reversal else self._ZH_PREFIX
        self.lang_prefix, self.lang_prefix_id = prefix

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Add the language-tag and separator tokens to one or two sequences.

        The resulting layout is:

        - single sequence: ``<lang> X </s>``
        - pair of sequences: ``<lang> A </s> B </s>``

        where ``<lang>`` is the current ``lang_prefix_id``. No BOS token is
        added.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: the input IDs with the special tokens added.
        """
        sep = [self.sep_token_id]
        first = [self.lang_prefix_id] + token_ids_0 + sep
        if token_ids_1 is None:
            return first
        return first + token_ids_1 + sep

    def moses_pipeline(self, text, lang):
        """Apply Moses punctuation normalisation only.

        Args:
            text: The string to normalise.
            lang: Language code passed to ``moses_punct_norm``.

        Returns:
            The normalised string.
        """
        return self.moses_punct_norm(text, lang)

    def _tokenize(self, text, lang="en", bypass_tokenizer=False):
        """Split ``text`` into BPE tokens.

        Overridden because, per the original author's note, the stock
        FSMTTokenizer converts Chinese punctuation into its English
        counterpart.

        Args:
            text: The string to tokenize.
            lang: Accepted for signature compatibility; the language actually
                used is the current ``self.lang_prefix``.
            bypass_tokenizer: If true, split on whitespace instead of running
                the Moses normalisation / tokenization steps.

        Returns:
            `List[str]`: the BPE pieces of every non-empty word, in order.
        """
        if self.do_lower_case:
            text = text.lower()

        if bypass_tokenizer:
            words = text.split()
        else:
            normalized = self.moses_pipeline(text, lang=self.lang_prefix)
            words = self.moses_tokenize(normalized, lang=self.lang_prefix)

        split_tokens = []
        for word in words:
            if word:
                split_tokens.extend(self.bpe(word).split(" "))
        return split_tokens

    def convert_tokens_to_string(self, tokens):
        """Join tokens into text and tighten spacing around non-ASCII text.

        After the parent class has produced the string, every whitespace run
        that directly precedes a character other than an ASCII letter, digit
        or space is removed (e.g. the gaps between Chinese characters or in
        front of punctuation). Whitespace between such a character and a
        following ASCII letter or digit is kept.

        Original author's note (translated): removes the spaces around
        non-English characters; this may be better handled in application
        code.

        Args:
            tokens: The tokens to join.

        Returns:
            The decoded string with the whitespace described above removed.
        """
        text = super().convert_tokens_to_string(tokens)
        return self.space_re.sub('', text)
if __name__ == '__main__':
    # Manual smoke test: expects the tokenizer files to be present in the
    # current working directory.
    tokenizer = FSMTTokenizer.from_pretrained(r'./')

    # Default direction (reversal is False): English input, Chinese target.
    forward_batch = tokenizer(['hello'], text_target=['你好朋友'])
    print(forward_batch)

    # Reversed direction: Chinese input, English target.
    tokenizer.reversal = True
    backward_batch = tokenizer(['你好朋友'], text_target=['hello'])
    print(backward_batch)