# princepride's picture
# Update model.py
# f3f7af4 verified
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from modules.file import ExcelFileWriter
import os
from abc import ABC, abstractmethod
from typing import List
import re
class FilterPipeline():
    """Chains filters: encoders run front-to-back, decoders back-to-front."""

    def __init__(self, filter_list):
        # The annotation is quoted: ``Filter`` is defined later in this module,
        # and attribute annotations are evaluated at runtime (PEP 526), so an
        # unquoted forward reference would raise NameError here.
        self._filter_list: "List[Filter]" = filter_list

    def append(self, filter):
        """Add one more filter to the end of the pipeline."""
        self._filter_list.append(filter)

    def batch_encoder(self, inputs):
        """Apply every filter's encoder to *inputs*, in registration order."""
        for flt in self._filter_list:  # 'flt' avoids shadowing builtin filter
            inputs = flt.encoder(inputs)
        return inputs

    def batch_decoder(self, inputs):
        """Apply every filter's decoder to *inputs*, in reverse order."""
        for flt in reversed(self._filter_list):
            inputs = flt.decoder(inputs)
        return inputs
class Filter(ABC):
    """Abstract base class: the encode/decode contract every filter implements.

    ``encoder`` may drop or rewrite entries and records whatever bookkeeping
    it needs in ``self.code``; ``decoder`` uses that bookkeeping to restore
    the original entries.
    """

    def __init__(self):
        self.name = 'filter'  # human-readable filter name
        self.code = []        # bookkeeping the decoder uses to undo encoding

    @abstractmethod
    def encoder(self, inputs):
        """Encode or filter *inputs*; subclasses must override."""
        ...

    @abstractmethod
    def decoder(self, inputs):
        """Restore *inputs* to their pre-encoding form; subclasses must override."""
        ...
class SpecialTokenFilter(Filter):
    """Removes strings made up entirely of special tokens; restores them on decode.

    Note: an empty string vacuously satisfies the all() test, so empty
    strings are also filtered out and restored.
    """

    def __init__(self):
        self.name = 'special token filter'
        self.code = []
        self.special_tokens = ['!', '!', '-']  # characters treated as "special"

    def encoder(self, inputs):
        """Drop strings consisting only of special tokens, remembering their slots."""
        self.code = []
        kept = []
        for idx, text in enumerate(inputs):
            if all(ch in self.special_tokens for ch in text):
                # Record (position, content) so decode can reinsert it.
                self.code.append([idx, text])
            else:
                kept.append(text)
        return kept

    def decoder(self, inputs):
        """Reinsert every previously dropped string at its original position."""
        restored = inputs.copy()
        for idx, text in self.code:
            restored.insert(idx, text)
        return restored
class SperSignFilter(Filter):
    """Protects '%s' placeholders by swapping them for '*' during encoding."""

    def __init__(self):
        self.name = 's percentage sign filter'
        self.code = []

    def encoder(self, inputs):
        """Replace '%s' with '*', remembering which entries were touched."""
        self.code = []
        encoded = []
        for idx, text in enumerate(inputs):
            if '%s' in text:
                self.code.append(idx)  # remember the slot holding a placeholder
                text = text.replace('%s', '*')
            encoded.append(text)
        return encoded

    def decoder(self, inputs):
        """Swap every '*' back to '%s' in the entries recorded by the encoder."""
        decoded = inputs.copy()
        for idx in self.code:
            decoded[idx] = decoded[idx].replace('*', '%s')
        return decoded
class ParenSParenFilter(Filter):
    """Protects literal '(s)' suffixes by swapping them for '$' during encoding."""

    def __init__(self):
        self.name = 'Paren s paren filter'
        self.code = []

    def encoder(self, inputs):
        """Replace '(s)' with '$', remembering which entries were touched."""
        self.code = []
        encoded = []
        for idx, text in enumerate(inputs):
            if '(s)' in text:
                self.code.append(idx)  # remember the slot holding '(s)'
                text = text.replace('(s)', '$')
            encoded.append(text)
        return encoded

    def decoder(self, inputs):
        """Swap every '$' back to '(s)' in the entries recorded by the encoder."""
        decoded = inputs.copy()
        for idx in self.code:
            decoded[idx] = decoded[idx].replace('$', '(s)')
        return decoded
class ChevronsFilter(Filter):
    """Masks angle-bracketed segments (e.g. '<tag>') with '#' during encoding."""

    def __init__(self):
        self.name = 'chevrons filter'
        self.code = []

    def encoder(self, inputs):
        """Replace every non-greedy '<...>' match with '#', saving the originals."""
        self.code = []
        pattern = re.compile(r'<.*?>')
        encoded = []
        for idx, text in enumerate(inputs):
            matches = pattern.findall(text)
            if matches:
                self.code.append((idx, matches))  # (slot, ordered originals)
                text = pattern.sub('#', text)
            encoded.append(text)
        return encoded

    def decoder(self, inputs):
        """Substitute each '#' back with its original '<...>' content, in order."""
        decoded = inputs.copy()
        for idx, matches in self.code:
            for original in matches:
                # count=1 restores the masked segments left-to-right.
                decoded[idx] = decoded[idx].replace('#', original, 1)
        return decoded
class SimilarFilter(Filter):
    # Groups consecutive strings that differ only in their digits, keeping one
    # representative for translation and re-expanding the group on decode.
    def __init__(self):
        self.name = 'similar filter'
        self.code = []
    def is_similar(self, str1, str2):
        # Two strings are "similar" when they compare equal after all digit
        # runs are stripped from both.
        pattern = re.compile(r'\d+')
        return pattern.sub('', str1) == pattern.sub('', str2)
    def encoder(self, inputs):
        # Keep only the first string of each run of consecutive similar
        # strings; record the run's start index (in the ORIGINAL list) and the
        # run's full contents so decode can re-expand it.
        encoded_inputs = []
        self.code = []
        i = 0
        while i < len(inputs):
            encoded_inputs.append(inputs[i])
            similar_strs = [inputs[i]]
            j = i + 1
            while j < len(inputs) and self.is_similar(inputs[i], inputs[j]):
                similar_strs.append(inputs[j])
                j += 1
            if len(similar_strs) > 1:
                self.code.append((i, similar_strs))
            i = j
        return encoded_inputs
    def decoder(self, inputs):
        # Re-expand each recorded run by cloning the (possibly translated)
        # representative and splicing each member's original digits back in.
        # NOTE(review): decoded_inputs deliberately ALIASES *inputs* (no copy).
        # Insertions for earlier runs shift later entries back to their
        # original positions, which keeps inputs[i] valid for later runs —
        # do not "fix" this to use a copy.
        decoded_inputs = inputs
        for i, similar_strs in self.code:
            pattern = re.compile(r'\d+')
            for j in range(len(similar_strs)):
                if pattern.search(similar_strs[j]):
                    # Only the FIRST digit group of the original member is
                    # reused; every digit run in the representative is
                    # replaced with it.
                    number = re.findall(r'\d+', similar_strs[j])[0]
                    new_str = pattern.sub(number, inputs[i])
                else:
                    new_str = inputs[i]
                if j > 0:
                    decoded_inputs.insert(i + j, new_str)
        return decoded_inputs
class ChineseFilter:
    """Filters out single capitalized words that segment fully into pinyin.

    Such words are assumed to be romanized Chinese names that should not be
    machine-translated; they are removed before translation and reinserted
    at their original positions afterwards.
    """

    def __init__(self, pinyin_lib_file='pinyin.txt'):
        self.name = 'chinese filter'
        self.code = []
        # Load the pinyin syllable library once, up front.
        self.pinyin_lib = self.load_pinyin_lib(pinyin_lib_file)

    def load_pinyin_lib(self, file_path):
        """Load the pinyin syllable library (one syllable per line) into a set.

        Relies on module-level ``script_dir`` to locate the file.
        """
        with open(os.path.join(script_dir, file_path), 'r', encoding='utf-8') as f:
            return set(line.strip().lower() for line in f)

    def is_valid_chinese(self, word):
        """Return True for a single, capitalized token that is entirely pinyin."""
        if not word:
            # Guard: word[0] below would raise IndexError on an empty string.
            return False
        if len(word.split()) == 1 and word[0].isupper():
            return self.is_pinyin(word.lower())
        return False

    def encoder(self, inputs):
        """Remove pinyin-like words from *inputs*, remembering their slots."""
        encoded_inputs = []
        self.code = []
        for i, word in enumerate(inputs):
            if self.is_valid_chinese(word):
                self.code.append((i, word))  # save (position, word) for decode
            else:
                encoded_inputs.append(word)
        return encoded_inputs

    def decoder(self, inputs):
        """Reinsert every filtered pinyin word at its original position."""
        decoded_inputs = inputs.copy()
        for i, word in self.code:
            decoded_inputs.insert(i, word)
        return decoded_inputs

    def is_pinyin(self, string):
        """Greedily check whether *string* segments completely into pinyin syllables.

        Tries the longest prefix first (pinyin syllables are at most 6 letters)
        and advances past each match; fails as soon as no prefix matches.
        """
        string = string.lower()
        max_len = 6  # longest possible pinyin syllable
        pos = 0
        while pos < len(string):
            for length in range(max_len, 0, -1):
                if string[pos:pos + length] in self.pinyin_lib:
                    pos += length
                    break
            else:
                # No prefix (not even a single character) matched.
                return False
        return True
# Directory containing this script; used by ChineseFilter to locate pinyin.txt.
script_dir = os.path.dirname(os.path.abspath(__file__))
# Three directory levels above this script; used by Model.generate as the
# root for the temp/empty.xlsx checkpoint file.
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(script_dir)))
class Model():
    """Wraps a causal LM and tokenizer for batched translation jobs."""

    def __init__(self, modelname, selected_lora_model, selected_gpu):
        def get_gpu_index(gpu_info, target_gpu_name):
            """
            Return the index of the GPU whose name contains *target_gpu_name*
            (case-insensitive substring match), or -1 when no GPU matches.
            """
            for i, name in enumerate(gpu_info):
                if target_gpu_name.lower() in name.lower():
                    return i
            return -1

        if selected_gpu != "cpu":
            gpu_count = torch.cuda.device_count()
            gpu_info = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
            selected_gpu_index = get_gpu_index(gpu_info, selected_gpu)
            self.device_name = f"cuda:{selected_gpu_index}"
        else:
            self.device_name = "cpu"
        print("device_name", self.device_name)
        # NOTE(review): selected_lora_model is accepted but never used here.
        self.model = AutoModelForCausalLM.from_pretrained(modelname, torch_dtype="auto").to(self.device_name)
        self.tokenizer = AutoTokenizer.from_pretrained(modelname)

    def generate(self, inputs, original_language, target_languages, max_batch_size):
        """
        Translate *inputs* from *original_language* into each language in
        *target_languages*.

        Returns a list with one entry per input string; each entry is a list
        of {"target_language", "generated_translation"} dicts.
        """
        filter_list = [SpecialTokenFilter(), ChevronsFilter(), SimilarFilter(), ChineseFilter()]
        filter_pipeline = FilterPipeline(filter_list)

        def process_gpu_translate_result(temp_outputs):
            # Flatten per-batch results and checkpoint them to an Excel file.
            outputs = []
            for temp_output in temp_outputs:
                length = len(temp_output[0]["generated_translation"])
                for i in range(length):
                    temp = []
                    for trans in temp_output:
                        temp.append({
                            "target_language": trans["target_language"],
                            "generated_translation": trans['generated_translation'][i],
                        })
                    outputs.append(temp)
            excel_writer = ExcelFileWriter()
            excel_writer.write_text(os.path.join(parent_dir, r"temp/empty.xlsx"), outputs, 'A', 1, len(outputs))

        if self.device_name == "cpu":
            # CPU path: tokenize everything once, translate per target language.
            input_ids = self.tokenizer(inputs, return_tensors="pt", padding=True, max_length=128).to(self.device_name)
            output = []
            for target_language in target_languages:
                # NOTE(review): language_mapping is not defined or imported in
                # this module — this branch raises NameError unless it is
                # provided elsewhere. Confirm before relying on the CPU path.
                target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
                generated_tokens = self.model.generate(
                    **input_ids,
                    forced_bos_token_id=target_lang_code,
                    max_length=128
                )
                generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
                output.append({
                    "target_language": target_language,
                    "generated_translation": generated_translation,
                })
            # Re-group so each row carries all target-language translations.
            outputs = []
            length = len(output[0]["generated_translation"])
            for i in range(length):
                temp = []
                for trans in output:
                    temp.append({
                        "target_language": trans["target_language"],
                        "generated_translation": trans['generated_translation'][i],
                    })
                outputs.append(temp)
            return outputs
        else:
            # Cast once so both the batch split and the periodic checkpoint
            # arithmetic below work even if max_batch_size arrives as a string
            # (previously only the batch split cast it, so the checkpoint
            # condition could raise TypeError).
            max_batch_size = int(max_batch_size)
            print("length of inputs: ", len(inputs))
            batch_size = min(len(inputs), max_batch_size)
            batches = [inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)]
            print("length of batches size: ", len(batches))
            temp_outputs = []
            processed_num = 0
            for index, batch in enumerate(batches):
                print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
                print(len(batch))
                print(batch)
                batch = filter_pipeline.batch_encoder(batch)
                print(batch)
                temp = []
                if len(batch) > 0:
                    # NOTE(review): results of this branch are never passed
                    # through filter_pipeline.batch_decoder, so strings removed
                    # by the filters are not reinserted here — confirm whether
                    # decoding happens in the caller.
                    for target_language in target_languages:
                        # One chat prompt per input string.
                        batch_messages = [[
                            {"role": "system", "content": f"You are an expert in translating {original_language} to {target_language} for ERP systems. Your task is to translate markdown-formatted text from {original_language} to {target_language}. The text to be translated may not necessarily be complete phrases or sentences, but you must translate it into the corresponding language based on your own understanding, preserving its formatting without adding extra content."},
                            {"role": "user", "content": text},
                        ] for text in batch]
                        batch_texts = [self.tokenizer.apply_chat_template(
                            messages,
                            tokenize=False,
                            add_generation_prompt=True
                        ) for messages in batch_messages]
                        # Left padding so generation continues from prompt end.
                        self.tokenizer.padding_side = "left"
                        model_inputs = self.tokenizer(
                            batch_texts,
                            return_tensors="pt",
                            padding="longest",
                            truncation=True,
                        ).to(self.device_name)
                        generated_ids = self.model.generate(
                            max_new_tokens=512,
                            **model_inputs
                        )
                        # Keep only the newly generated tokens per sequence.
                        new_tokens = [
                            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
                        ]
                        generated_translation = self.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
                        temp.append({
                            "target_language": target_language,
                            "generated_translation": generated_translation,
                        })
                        # Free GPU memory before the next language pass.
                        model_inputs.to('cpu')
                        del model_inputs
                else:
                    # Every string in this batch was filtered out; undo the
                    # filters so the originals pass through untranslated.
                    for target_language in target_languages:
                        generated_translation = filter_pipeline.batch_decoder(batch)
                        print(generated_translation)
                        print(len(generated_translation))
                        temp.append({
                            "target_language": target_language,
                            "generated_translation": generated_translation,
                        })
                temp_outputs.append(temp)
                processed_num += len(batch)
                # Checkpoint to Excel roughly every 1000 processed inputs.
                if (index + 1) * max_batch_size // 1000 - index * max_batch_size // 1000 == 1:
                    print("Already processed number: ", len(temp_outputs))
                    process_gpu_translate_result(temp_outputs)
            # Flatten: one row per input, each carrying all target languages.
            # (Removed: unreachable leftover FilterPipeline code that had been
            # pasted after this return, including a broken duplicate
            # batch_decoder method referencing a nonexistent self._filter_list.)
            outputs = []
            for temp_output in temp_outputs:
                length = len(temp_output[0]["generated_translation"])
                for i in range(length):
                    temp = []
                    for trans in temp_output:
                        temp.append({
                            "target_language": trans["target_language"],
                            "generated_translation": trans['generated_translation'][i],
                        })
                    outputs.append(temp)
            return outputs