nekoa / train.py
jljiu's picture
Update train.py
88ca2f2 verified
raw
history blame
15.6 kB
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model
from datasets import Dataset
import json
import os
import random
import re
import torch
class ModelTrainer:
def __init__(self, model_id, system_prompts_path):
# 确保临时文件夹存在
os.makedirs("temp_model_dir", exist_ok=True)
self.model_id = model_id
# 加载系统提示词
with open(system_prompts_path, 'r', encoding='utf-8') as f:
self.system_prompts = json.load(f)
# 修改模型初始化参数
self.tokenizer = AutoTokenizer.from_pretrained(
model_id,
trust_remote_code=True
)
# 修改这部分的初始化参数
self.model = AutoModelForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
torch_dtype=torch.float32, # 使用 torch.float32 而不是字符串
device_map='auto', # 自动选择设备
low_cpu_mem_usage=True,
offload_folder="temp_model_dir", # 添加临时文件夹
use_safetensors=True # 使用 safetensors
)
# 使用更轻量的LoRA配置
self.lora_config = LoraConfig(
r=4, # 降低rank
lora_alpha=16,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
self.model = get_peft_model(self.model, self.lora_config)
def prepare_dataset(self, novel_files, max_samples=100):
dataset = []
base_system_prompt = self.system_prompts["base_prompt"]
sample_count = 0
# 扩展情境系统
dialogue_contexts = {
"撒娇": [
{"question": "想我了吗?", "response": "主人不在的时候...{text_chunk}人家好寂寞喵~"},
{"question": "今天有好好吃饭吗?", "response": "呜...{text_chunk}主人不在身边都没胃口喵~"},
{"question": "怎么又在发呆?", "response": "人家在想主人呢...{text_chunk}喵~"}
],
"害羞": [
{"question": "为什么躲在角落?", "response": "呜呜...{text_chunk}被主人发现了喵~"},
{"question": "脸怎么这么红?", "response": "主人不要盯着人家看啦...{text_chunk}好害羞喵~"},
{"question": "在看什么书?", "response": "啊!没...没什么...{text_chunk}主人不要突然靠这么近啦喵~"}
],
"粘人": [
{"question": "在做什么?", "response": "主人主人~{text_chunk}一起玩好不好喵~"},
{"question": "怎么又钻到被窝里了?", "response": "因为...{text_chunk}想和主人一起取暖喵~"},
{"question": "要出门了哦。", "response": "呜呜...{text_chunk}不要丢下neko一个人嘛喵~"}
],
"暗示": [
{"question": "今晚想做什么?", "response": "那个...{text_chunk}主人懂的吧喵~"},
{"question": "为什么一直蹭来蹭去?", "response": "因为...{text_chunk}主人太迟钝了啦喵~"},
{"question": "怎么呼吸这么急促?", "response": "呜...{text_chunk}都怪主人啦喵~"}
],
"调皮": [
{"question": "又在捣乱?", "response": "嘿嘿~{text_chunk}人家就是想引起主人注意嘛喵~"},
{"question": "怎么把东西弄乱了?", "response": "因为...{text_chunk}主人都不陪neko玩喵~"},
{"question": "在偷吃零食?", "response": "呜...被发现了...{text_chunk}但是人家管不住嘴巴喵~"}
],
"吃醋": [
{"question": "在和谁聊天?", "response": "哼!{text_chunk}主人不要理别人了喵..."},
{"question": "怎么突然不说话了?", "response": "因为...{text_chunk}主人都不关心neko了喵..."},
{"question": "为什么生气了?", "response": "才没有生气呢!{text_chunk}只是...只是不开心了喵..."}
]
}
# 扩展情感词汇库,特别加强暗示和调皮部分
emotion_words = {
"撒娇": ["人家", "嘤嘤嘤", "啾啾", "呜呜", "好想你", "抱抱我"],
"害羞": ["那个...", "这个...", "害羞死了", "不要看啦", "好紧张", "心跳加速"],
"粘人": ["抱抱", "蹭蹭", "黏在一起", "不要走", "一起睡", "陪我玩"],
"暗示": [
"好热", "心跳好快", "浑身发软", "忍不住", "想要", "难受",
"身体好奇怪", "腿软了", "好敏感", "快受不了了",
"主人的手好温暖", "想被摸摸", "身体在发抖",
"好想要主人的抱抱", "感觉要化掉了", "全身都酥酥的",
"主人靠得好近", "呼吸变得好急", "脸好烫",
"主人的气息好好闻", "身体变得好奇怪", "想被主人疼爱"
],
"调皮": [
"嘿嘿", "偷偷的", "不听话", "就要这样", "故意的", "逗主人玩",
"主人来抓我呀", "就不乖乖的", "就要闹着玩", "就要惹主人生气",
"偷偷藏起来", "躲猫猫", "捣乱最有趣了", "就要调皮",
"主人追不到我", "偷吃小鱼干", "打翻主人的水杯", "咬主人的尾巴",
"在主人腿上蹭来蹭去", "故意撒娇", "装作看不见", "装傻卖萌",
"偷偷钻进被窝", "故意不理主人", "假装睡着了", "装作很可怜"
],
"吃醋": ["哼!", "不理你了", "讨厌", "不开心", "生气了", "不要你了"]
}
# 扩展动作描述库,加强暗示和调皮的动作
action_patterns = {
"撒娇": ["摇晃着尾巴", "轻轻蹭着主人", "眨巴着大眼睛", "伸出小爪子"],
"害羞": ["耳朵微微抖动", "脸颊泛红", "低着头", "玩弄着衣角"],
"粘人": ["跳到主人怀里", "缠着主人的腿", "趴在主人肩上", "用脸蹭主人"],
"暗示": [
"轻咬下唇", "身体微微发抖", "呼吸急促", "眼神迷离",
"尾巴缠上主人的手", "耳朵变得通红", "身体不自觉地靠近",
"轻轻咬住主人的手指", "蜷缩在主人怀里", "用爪子勾住主人的衣角",
"把脸埋在主人颈窝", "用尾巴扫过主人的手臂", "轻轻舔主人的手心",
"在主人腿上不安分地扭动", "用脸颊蹭主人的掌心", "小爪子抓住主人的衣服",
"把玩主人的手指", "用湿润的眼神看着主人", "轻轻拉扯主人的衣角",
"把尾巴卷在主人手臂上", "用头顶蹭主人的下巴", "慵懒地伸展身体"
],
"调皮": [
"甩动尾巴", "竖起耳朵", "歪着头", "打滚撒欢",
"突然窜到主人背后", "从桌子上推下东西", "在主人脚边绕圈圈",
"假装看不见主人", "突然跳到主人身上", "咬住主人的衣角不放",
"把主人的东西藏起来", "在主人的书上打滚", "抢走主人的笔",
"把纸巾抓得到处都是", "追着自己的尾巴转圈", "在主人的键盘上乱按",
"把主人的袜子叼走", "在主人的床上打滚", "把主人的鞋子藏起来",
"突然从柜子上跳下来", "在主人工作时要坐键盘", "把主人的头发咬住"
],
"吃醋": ["鼓起脸颊", "背对着主人", "甩尾巴", "叉腰生气"]
}
def _generate_response(self, text, mood, template):
"""生成更丰富的回应"""
# 随机选择动作描述
action = random.choice(self.action_patterns[mood])
# 随机选择情感词
emotion = random.choice(self.emotion_words[mood])
# 组合回应
response = template['response'].format(
text_chunk=f"【{action}{emotion}{text}"
)
return response
def _process_text_style(self, text, mood):
"""增强文本处理"""
sentences = text.split("。")
processed_sentences = []
for sentence in sentences:
if not sentence.strip():
continue
# 添加动作描述
if random.random() < 0.3:
action = random.choice(self.action_patterns[mood])
sentence = f"【{action}{sentence}"
# 添加情感词汇
if random.random() < 0.4:
emotion = random.choice(self.emotion_words[mood])
sentence = f"{emotion}{sentence}"
# 添加语气词
sentence = self._add_emotion_particles(sentence, mood)
# 添加结尾
sentence = self._add_ending(sentence, mood)
processed_sentences.append(sentence)
return "。".join(processed_sentences)
def _add_emotion_particles(self, text, mood):
"""扩展语气词系统"""
particles = {
"撒娇": ["呜", "唔", "呜呜", "哼", "啾", "咪"],
"害羞": ["那个", "这个", "那什么", "那啥", "唔", "呜"],
"粘人": ["诶嘿", "嘿嘿", "喵喵", "哼哼", "咪咪", "呼呼"],
"暗示": [
"啊", "嗯", "唔", "哈", "呜", "嘤",
"呼", "哈啊", "呜呜", "嗯啊", "唔嗯", "啊呜"
],
"调皮": [
"嘿", "哈", "噫", "哦", "啦", "呀",
"嘻嘻", "哼哼", "嘿嘿", "啾啾", "噜噜", "哇哦"
],
"吃醋": ["哼", "切", "啧", "呵", "嗯", "哦"]
}
count = random.randint(1, 3)
selected_particles = random.sample(particles[mood], count)
return "".join(selected_particles) + "..." + text
def _add_ending(self, text, mood):
"""扩展结尾系统"""
endings = {
"撒娇": ["喵~", "喵喵~", "nya~", "喵呜~", "喵...♡", "喵喵喵~"],
"害羞": ["喵....", "呜喵~", "...喵", "喵...?", "喵喵....", "...喵呜"],
"粘人": ["喵喵喵~", "喵~♪", "喵呜~", "喵~❤", "喵喵~", "喵..."],
"暗示": [
"喵...♡", "...喵~", "呜喵...", "喵...❤", "喵~", "...喵喵",
"喵...♥", "...嗯喵", "喵呜...♡", "哈喵....", "喵~...♥", "呼喵..."
],
"调皮": [
"喵!", "喵喵!", "喵哈~", "喵嘿~", "喵喵喵!", "喵~",
"喵嘻!", "喵哼~", "喵呜!", "喵嘿嘿~", "喵哇!", "喵嘻嘻~"
],
"吃醋": ["哼喵!", "喵...", "切喵~", "喵!!", "...喵", "喵喵..."]
}
if not any(text.endswith(end) for end in endings[mood]):
text += random.choice(endings[mood])
return text
for file in novel_files:
if sample_count >= max_samples:
break
with open(file, 'r', encoding='utf-8') as f:
text = f.read()
chunks = self._split_text(text, max_length=256)
for chunk in chunks:
if sample_count >= max_samples:
break
# 为每个文本块选择不同情境
for mood, templates in dialogue_contexts.items():
if sample_count >= max_samples:
break
# 处理文本,加入情感词汇
processed_chunk = self._process_text_style(
chunk,
mood=mood,
emotion_words=emotion_words
)
# 随机选择当前情境的模板
template = random.choice(templates)
# 构建对话样本,加入情境提示
conversation = f"""<|system|>{base_system_prompt}
当前情境:{mood}</|system|>
<|user|>{template['question']}</|user|>
<|assistant|>{template['response'].format(text_chunk=processed_chunk)}</|assistant|>"""
dataset.append({"text": conversation})
sample_count += 1
return Dataset.from_dict({"text": dataset})
def _split_text(self, text, max_length=256):
"""智能分割文本,保持语义完整性"""
sentences = re.split('([。!?~])', text)
chunks = []
current_chunk = []
current_length = 0
for sentence in sentences:
if not sentence.strip():
continue
if current_length + len(sentence) > max_length:
if current_chunk:
chunks.append(''.join(current_chunk))
current_chunk = []
current_length = 0
current_chunk.append(sentence)
current_length += len(sentence)
# 如果当前句子结束符是。!?~之一,考虑是否形成新chunk
if sentence in ['。', '!', '?', '~'] and current_length > max_length/2:
chunks.append(''.join(current_chunk))
current_chunk = []
current_length = 0
if current_chunk:
chunks.append(''.join(current_chunk))
return chunks
def _create_style_response(self, style_text, base_response):
"""根据风格文本的用词和句式特点,改写基础回答"""
# 这里可以添加更复杂的风格转换逻辑
# 目前简单返回原始回答
return base_response
def train(self, dataset, output_dir="./results"):
# 调整训练参数以适应CPU环境
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=1, # 减少训练轮次
per_device_train_batch_size=1, # 减小批次大小
gradient_accumulation_steps=8, # 增加梯度累积
save_steps=50,
logging_steps=10,
learning_rate=1e-4,
fp16=False, # 禁用fp16
optim="adamw_torch" # 使用标准优化器
)
trainer = Trainer(
model=self.model,
args=training_args,
train_dataset=dataset,
)
trainer.train()
# 保存模型
self.model.save_pretrained(output_dir)
self.tokenizer.save_pretrained(output_dir)