from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from modules.file import ExcelFileWriter
import os

from abc import ABC, abstractmethod
from typing import List
import re


class FilterPipeline():
    def __init__(self, filter_list):
        self._filter_list: List["Filter"] = filter_list

    def append(self, filter):
        self._filter_list.append(filter)

    def batch_encoder(self, inputs):
        # Apply each filter's encoder in registration order.
        for filter in self._filter_list:
            inputs = filter.encoder(inputs)
        return inputs

    def batch_decoder(self, inputs):
        # Undo the filters in reverse order so the encoding unwinds cleanly.
        for filter in reversed(self._filter_list):
            inputs = filter.decoder(inputs)
        return inputs


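# Minimal round-trip sketch (illustrative only; uses the filter classes defined
# below). Encoding strips or masks fragile substrings, translation runs on the
# cleaned batch, and decoding restores what was removed:
#
#     pipeline = FilterPipeline([SpecialTokenFilter(), ChevronsFilter()])
#     encoded = pipeline.batch_encoder(["<b>Save</b>", "---"])  # -> ["#Save#"]
#     # ... translate `encoded` here ...
#     restored = pipeline.batch_decoder(encoded)  # -> ["<b>Save</b>", "---"]

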
class Filter(ABC):
    def __init__(self):
        self.name = 'filter'
        self.code = []

    @abstractmethod
    def encoder(self, inputs):
        pass

    @abstractmethod
    def decoder(self, inputs):
        pass


class SpecialTokenFilter(Filter):
    def __init__(self):
        self.name = 'special token filter'
        self.code = []
        self.special_tokens = ['！', '!', '-']

    def encoder(self, inputs):
        # Drop strings that consist entirely of special tokens and record
        # their original positions so decoder() can reinsert them.
        filtered_inputs = []
        self.code = []
        for i, input_str in enumerate(inputs):
            if not all(char in self.special_tokens for char in input_str):
                filtered_inputs.append(input_str)
            else:
                self.code.append([i, input_str])
        return filtered_inputs

    def decoder(self, inputs):
        original_inputs = inputs.copy()
        for removed_indice in self.code:
            original_inputs.insert(removed_indice[0], removed_indice[1])
        return original_inputs


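# Example (illustrative): with inputs ["Open", "!!", "-"], encoder() returns
# ["Open"] and stores [[1, "!!"], [2, "-"]]; decoder() reinserts both strings
# at their original indices after translation.

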
class SperSignFilter(Filter):
    def __init__(self):
        self.name = 's percentage sign filter'
        self.code = []

    def encoder(self, inputs):
        # Replace '%s' placeholders with a single '*' sentinel so the model
        # does not mangle them (assumes inputs contain no literal '*').
        encoded_inputs = []
        self.code = []
        for i, input_str in enumerate(inputs):
            if '%s' in input_str:
                encoded_str = input_str.replace('%s', '*')
                self.code.append(i)
            else:
                encoded_str = input_str
            encoded_inputs.append(encoded_str)
        return encoded_inputs

    def decoder(self, inputs):
        decoded_inputs = inputs.copy()
        for i in self.code:
            decoded_inputs[i] = decoded_inputs[i].replace('*', '%s')
        return decoded_inputs


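# Example (illustrative): "Deleted %s records" encodes to "Deleted * records";
# decoder() maps the sentinel back to '%s' at the recorded indices. The
# ParenSParenFilter below applies the same scheme to '(s)' using '$'.

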
class ParenSParenFilter(Filter):
    def __init__(self):
        self.name = 'Paren s paren filter'
        self.code = []

    def encoder(self, inputs):
        # Same sentinel scheme as SperSignFilter, for '(s)' pluralisation marks.
        encoded_inputs = []
        self.code = []
        for i, input_str in enumerate(inputs):
            if '(s)' in input_str:
                encoded_str = input_str.replace('(s)', '$')
                self.code.append(i)
            else:
                encoded_str = input_str
            encoded_inputs.append(encoded_str)
        return encoded_inputs

    def decoder(self, inputs):
        decoded_inputs = inputs.copy()
        for i in self.code:
            decoded_inputs[i] = decoded_inputs[i].replace('$', '(s)')
        return decoded_inputs


class ChevronsFilter(Filter):
    def __init__(self):
        self.name = 'chevrons filter'
        self.code = []

    def encoder(self, inputs):
        # Mask every '<...>' tag with '#' and remember the matches per index
        # (assumes inputs contain no literal '#').
        encoded_inputs = []
        self.code = []
        pattern = re.compile(r'<.*?>')
        for i, input_str in enumerate(inputs):
            if pattern.search(input_str):
                matches = pattern.findall(input_str)
                encoded_str = pattern.sub('#', input_str)
                self.code.append((i, matches))
            else:
                encoded_str = input_str
            encoded_inputs.append(encoded_str)
        return encoded_inputs

    def decoder(self, inputs):
        decoded_inputs = inputs.copy()
        for i, matches in self.code:
            # Replace one '#' per match, left to right, restoring tag order.
            for match in matches:
                decoded_inputs[i] = decoded_inputs[i].replace('#', match, 1)
        return decoded_inputs


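# Example (illustrative): "<b>Save</b> file" encodes to "#Save# file" with
# code [(0, ['<b>', '</b>'])]; decoder() re-substitutes '<b>' then '</b>'.

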
class SimilarFilter(Filter):
    def __init__(self):
        self.name = 'similar filter'
        self.code = []

    def is_similar(self, str1, str2):
        # Two strings are "similar" if they are identical once digits are removed.
        pattern = re.compile(r'\d+')
        return pattern.sub('', str1) == pattern.sub('', str2)

    def encoder(self, inputs):
        # Collapse each run of adjacent similar strings down to its first
        # element, recording the run so decoder() can expand it again.
        encoded_inputs = []
        self.code = []
        i = 0
        while i < len(inputs):
            encoded_inputs.append(inputs[i])
            similar_strs = [inputs[i]]
            j = i + 1
            while j < len(inputs) and self.is_similar(inputs[i], inputs[j]):
                similar_strs.append(inputs[j])
                j += 1
            if len(similar_strs) > 1:
                self.code.append((i, similar_strs))
            i = j
        return encoded_inputs

    def decoder(self, inputs):
        decoded_inputs = inputs.copy()
        for i, similar_strs in self.code:
            pattern = re.compile(r'\d+')
            for j in range(len(similar_strs)):
                if pattern.search(similar_strs[j]):
                    # Re-apply each member's digits to the translated representative.
                    number = re.findall(r'\d+', similar_strs[j])[0]
                    new_str = pattern.sub(number, decoded_inputs[i])
                else:
                    new_str = decoded_inputs[i]
                if j > 0:
                    decoded_inputs.insert(i + j, new_str)
        return decoded_inputs


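# Example (illustrative): ["Line 1", "Line 2", "Line 3"] encodes to ["Line 1"];
# after translation, decoder() clones the translated string and swaps in each
# member's own digits, yielding three aligned outputs again.

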
class ChineseFilter(Filter):
    def __init__(self, pinyin_lib_file='pinyin.txt'):
        self.name = 'chinese filter'
        self.code = []
        self.pinyin_lib = self.load_pinyin_lib(pinyin_lib_file)

    def load_pinyin_lib(self, file_path):
        # script_dir is defined at module level below; it is available by the
        # time this method runs.
        with open(os.path.join(script_dir, file_path), 'r', encoding='utf-8') as f:
            return set(line.strip().lower() for line in f)

    def is_valid_chinese(self, word):
        # Treat a single capitalised token as a Chinese name if it segments
        # fully into pinyin syllables.
        if len(word.split()) == 1 and word[0].isupper():
            return self.is_pinyin(word.lower())
        return False

    def encoder(self, inputs):
        encoded_inputs = []
        self.code = []
        for i, word in enumerate(inputs):
            if self.is_valid_chinese(word):
                # Pinyin names pass through untranslated.
                self.code.append((i, word))
            else:
                encoded_inputs.append(word)
        return encoded_inputs

    def decoder(self, inputs):
        decoded_inputs = inputs.copy()
        for i, word in self.code:
            decoded_inputs.insert(i, word)
        return decoded_inputs

    def is_pinyin(self, string):
        # Greedy longest-match segmentation against the pinyin syllable library:
        # repeatedly take the longest prefix (up to 6 letters) found in the
        # library; fail as soon as no prefix matches.
        string = string.lower()
        stringlen = len(string)
        max_len = 6
        result = []
        n = 0
        while n < stringlen:
            matched = 0
            temp_result = []
            for i in range(max_len, 0, -1):
                s = string[0:i]
                if s in self.pinyin_lib:
                    temp_result.append(string[:i])
                    matched = i
                    break
                if i == 1 and len(temp_result) == 0:
                    return False
            result.extend(temp_result)
            string = string[matched:]
            n += matched
        return True


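# Example (illustrative, assuming 'pinyin.txt' holds the standard syllable
# list): is_pinyin("zhangsan") matches "zhang" then "san" and returns True, so
# is_valid_chinese("Zhangsan") treats the word as a name; is_pinyin("update")
# finds no valid prefix and returns False.

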
script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(script_dir)))


class Model():
    def __init__(self, modelname, selected_lora_model, selected_gpu):
        def get_gpu_index(gpu_info, target_gpu_name):
            """
            Find the index of the target GPU in the device-name list.

            Args:
                gpu_info (list): list of GPU device names
                target_gpu_name (str): name of the GPU to look for

            Returns:
                int: index of the target GPU, or -1 if it is not found
            """
            for i, name in enumerate(gpu_info):
                if target_gpu_name.lower() in name.lower():
                    return i
            return -1
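
        # For example (hypothetical device list):
        #   get_gpu_index(["NVIDIA GeForce RTX 4090", "Tesla T4"], "4090") -> 0
        #   get_gpu_index(["Tesla T4"], "A100") -> -1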
        if selected_gpu != "cpu":
            gpu_count = torch.cuda.device_count()
            gpu_info = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
            selected_gpu_index = get_gpu_index(gpu_info, selected_gpu)
            if selected_gpu_index == -1:
                # Fall back to the first GPU rather than building an invalid
                # "cuda:-1" device string.
                selected_gpu_index = 0
            self.device_name = f"cuda:{selected_gpu_index}"
        else:
            self.device_name = "cpu"
        print("device_name", self.device_name)
        self.model = AutoModelForCausalLM.from_pretrained(modelname, torch_dtype="auto").to(self.device_name)
        self.tokenizer = AutoTokenizer.from_pretrained(modelname)

    def generate(self, inputs, original_language, target_languages, max_batch_size):
        filter_list = [SpecialTokenFilter(), ChevronsFilter(), SimilarFilter(), ChineseFilter()]
        filter_pipeline = FilterPipeline(filter_list)

        def process_gpu_translate_result(temp_outputs):
            # Flatten per-batch results into one row per input string and
            # checkpoint them to an Excel file.
            outputs = []
            for temp_output in temp_outputs:
                length = len(temp_output[0]["generated_translation"])
                for i in range(length):
                    temp = []
                    for trans in temp_output:
                        temp.append({
                            "target_language": trans["target_language"],
                            "generated_translation": trans['generated_translation'][i],
                        })
                    outputs.append(temp)
            excel_writer = ExcelFileWriter()
            excel_writer.write_text(os.path.join(parent_dir, r"temp/empty.xlsx"), outputs, 'A', 1, len(outputs))

        if self.device_name == "cpu":
            # CPU path: seq2seq-style generation with a forced BOS language code.
            input_ids = self.tokenizer(inputs, return_tensors="pt", padding=True, truncation=True, max_length=128).to(self.device_name)
            output = []
            for target_language in target_languages:
                # language_mapping is expected to be provided elsewhere in the
                # project; it maps a display name to a tokenizer language code.
                target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
                generated_tokens = self.model.generate(
                    **input_ids,
                    forced_bos_token_id=target_lang_code,
                    max_length=128
                )
                generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
                output.append({
                    "target_language": target_language,
                    "generated_translation": generated_translation,
                })
            outputs = []
            length = len(output[0]["generated_translation"])
            for i in range(length):
                temp = []
                for trans in output:
                    temp.append({
                        "target_language": trans["target_language"],
                        "generated_translation": trans['generated_translation'][i],
                    })
                outputs.append(temp)
            return outputs
        else:
            print("length of inputs: ", len(inputs))
            batch_size = min(len(inputs), int(max_batch_size))
            batches = [inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)]
            print("number of batches: ", len(batches))
            temp_outputs = []
            processed_num = 0
            for index, batch in enumerate(batches):
                print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
                print(len(batch))
                print(batch)
                batch = filter_pipeline.batch_encoder(batch)
                print(batch)
                temp = []
                if len(batch) > 0:
                    for target_language in target_languages:
                        batch_messages = [[
                            {"role": "system", "content": f"You are an expert in translating {original_language} to {target_language} for ERP systems. Your task is to translate markdown-formatted text from {original_language} to {target_language}. The text to be translated may not necessarily be complete phrases or sentences, but you must translate it into the corresponding language based on your own understanding, preserving its formatting without adding extra content."},
                            {"role": "user", "content": text},
                        ] for text in batch]
                        batch_texts = [self.tokenizer.apply_chat_template(
                            messages,
                            tokenize=False,
                            add_generation_prompt=True
                        ) for messages in batch_messages]
                        # Left padding so the generated continuation lines up at the end.
                        self.tokenizer.padding_side = "left"
                        model_inputs = self.tokenizer(
                            batch_texts,
                            return_tensors="pt",
                            padding="longest",
                            truncation=True,
                        ).to(self.device_name)
                        generated_ids = self.model.generate(
                            max_new_tokens=512,
                            **model_inputs
                        )
                        # Strip the prompt tokens, keeping only the new completion.
                        new_tokens = [
                            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
                        ]
                        generated_translation = self.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
                        # Restore what the filters removed or masked so the outputs
                        # align one-to-one with the original batch.
                        generated_translation = filter_pipeline.batch_decoder(generated_translation)
                        temp.append({
                            "target_language": target_language,
                            "generated_translation": generated_translation,
                        })
                        model_inputs.to('cpu')
                        del model_inputs
                else:
                    # Every string in this batch was filtered out; decoding the
                    # empty batch just reinserts the removed originals.
                    for target_language in target_languages:
                        generated_translation = filter_pipeline.batch_decoder(batch)
                        print(generated_translation)
                        print(len(generated_translation))
                        temp.append({
                            "target_language": target_language,
                            "generated_translation": generated_translation,
                        })
                temp_outputs.append(temp)
                processed_num += len(batch)
                # Checkpoint roughly every 1000 processed inputs.
                if (index + 1) * batch_size // 1000 - index * batch_size // 1000 == 1:
                    print("Already processed number: ", len(temp_outputs))
                    process_gpu_translate_result(temp_outputs)
            outputs = []
            for temp_output in temp_outputs:
                length = len(temp_output[0]["generated_translation"])
                for i in range(length):
                    temp = []
                    for trans in temp_output:
                        temp.append({
                            "target_language": trans["target_language"],
                            "generated_translation": trans['generated_translation'][i],
                        })
                    outputs.append(temp)
            return outputs
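

if __name__ == "__main__":
    # Quick self-check of the filter pipeline round trip (illustrative; it
    # exercises only the filters, since Model needs model weights and
    # ChineseFilter needs pinyin.txt alongside this file).
    demo_pipeline = FilterPipeline([SpecialTokenFilter(), ChevronsFilter(), SimilarFilter()])
    demo_batch = ["<b>Save</b>", "---", "Row 1", "Row 2"]
    encoded = demo_pipeline.batch_encoder(demo_batch)
    print("encoded:", encoded)  # e.g. ['#Save#', 'Row 1']
    decoded = demo_pipeline.batch_decoder(encoded)
    print("decoded:", decoded)  # round-trips to the original batch
    assert decoded == demo_batch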