import json
import os
import random
import re

import torch
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
)


class ModelTrainer:

    def __init__(self, model_id, system_prompts_path):
        # Offload directory for weights that do not fit in device memory.
        os.makedirs("temp_model_dir", exist_ok=True)

        self.model_id = model_id

        # The prompt file is expected to provide at least a "base_prompt" key.
        with open(system_prompts_path, 'r', encoding='utf-8') as f:
            self.system_prompts = json.load(f)

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True
        )
        # Batched training needs padding, and many causal-LM tokenizers ship
        # without a pad token; fall back to EOS in that case.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.float32,
            device_map='auto',
            low_cpu_mem_usage=True,
            offload_folder="temp_model_dir",
            use_safetensors=True
        )

        # Attach small LoRA adapters to the attention projections so only a
        # fraction of the parameters are trained.
        self.lora_config = LoraConfig(
            r=4,
            lora_alpha=16,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )

        self.model = get_peft_model(self.model, self.lora_config)

    def prepare_dataset(self, novel_files, max_samples=100):
        """Build styled conversation samples from raw novel text files."""
        # The lexicons below are stored on the instance because the helper
        # methods (_generate_response, _process_text_style, ...) read them
        # through self; as plain locals they would be out of scope there.
        self.dialogue_contexts = {
            "撒娇": [
                {"question": "想我了吗?", "response": "主人不在的时候...{text_chunk}人家好寂寞喵~"},
                {"question": "今天有好好吃饭吗?", "response": "呜...{text_chunk}主人不在身边都没胃口喵~"},
                {"question": "怎么又在发呆?", "response": "人家在想主人呢...{text_chunk}喵~"}
            ],
            "害羞": [
                {"question": "为什么躲在角落?", "response": "呜呜...{text_chunk}被主人发现了喵~"},
                {"question": "脸怎么这么红?", "response": "主人不要盯着人家看啦...{text_chunk}好害羞喵~"},
                {"question": "在看什么书?", "response": "啊!没...没什么...{text_chunk}主人不要突然靠这么近啦喵~"}
            ],
            "粘人": [
                {"question": "在做什么?", "response": "主人主人~{text_chunk}一起玩好不好喵~"},
                {"question": "怎么又钻到被窝里了?", "response": "因为...{text_chunk}想和主人一起取暖喵~"},
                {"question": "要出门了哦。", "response": "呜呜...{text_chunk}不要丢下neko一个人嘛喵~"}
            ],
            "暗示": [
                {"question": "今晚想做什么?", "response": "那个...{text_chunk}主人懂的吧喵~"},
                {"question": "为什么一直蹭来蹭去?", "response": "因为...{text_chunk}主人太迟钝了啦喵~"},
                {"question": "怎么呼吸这么急促?", "response": "呜...{text_chunk}都怪主人啦喵~"}
            ],
            "调皮": [
                {"question": "又在捣乱?", "response": "嘿嘿~{text_chunk}人家就是想引起主人注意嘛喵~"},
                {"question": "怎么把东西弄乱了?", "response": "因为...{text_chunk}主人都不陪neko玩喵~"},
                {"question": "在偷吃零食?", "response": "呜...被发现了...{text_chunk}但是人家管不住嘴巴喵~"}
            ],
            "吃醋": [
                {"question": "在和谁聊天?", "response": "哼!{text_chunk}主人不要理别人了喵..."},
                {"question": "怎么突然不说话了?", "response": "因为...{text_chunk}主人都不关心neko了喵..."},
                {"question": "为什么生气了?", "response": "才没有生气呢!{text_chunk}只是...只是不开心了喵..."}
            ]
        }

        self.emotion_words = {
            "撒娇": ["人家", "嘤嘤嘤", "啾啾", "呜呜", "好想你", "抱抱我"],
            "害羞": ["那个...", "这个...", "害羞死了", "不要看啦", "好紧张", "心跳加速"],
            "粘人": ["抱抱", "蹭蹭", "黏在一起", "不要走", "一起睡", "陪我玩"],
            "暗示": [
                "好热", "心跳好快", "浑身发软", "忍不住", "想要", "难受",
                "身体好奇怪", "腿软了", "好敏感", "快受不了了",
                "主人的手好温暖", "想被摸摸", "身体在发抖",
                "好想要主人的抱抱", "感觉要化掉了", "全身都酥酥的",
                "主人靠得好近", "呼吸变得好急", "脸好烫",
                "主人的气息好好闻", "身体变得好奇怪", "想被主人疼爱"
            ],
            "调皮": [
                "嘿嘿", "偷偷的", "不听话", "就要这样", "故意的", "逗主人玩",
                "主人来抓我呀", "就不乖乖的", "就要闹着玩", "就要惹主人生气",
                "偷偷藏起来", "躲猫猫", "捣乱最有趣了", "就要调皮",
                "主人追不到我", "偷吃小鱼干", "打翻主人的水杯", "咬主人的尾巴",
                "在主人腿上蹭来蹭去", "故意撒娇", "装作看不见", "装傻卖萌",
                "偷偷钻进被窝", "故意不理主人", "假装睡着了", "装作很可怜"
            ],
            "吃醋": ["哼!", "不理你了", "讨厌", "不开心", "生气了", "不要你了"]
        }

        self.action_patterns = {
            "撒娇": ["摇晃着尾巴", "轻轻蹭着主人", "眨巴着大眼睛", "伸出小爪子"],
            "害羞": ["耳朵微微抖动", "脸颊泛红", "低着头", "玩弄着衣角"],
            "粘人": ["跳到主人怀里", "缠着主人的腿", "趴在主人肩上", "用脸蹭主人"],
            "暗示": [
                "轻咬下唇", "身体微微发抖", "呼吸急促", "眼神迷离",
                "尾巴缠上主人的手", "耳朵变得通红", "身体不自觉地靠近",
                "轻轻咬住主人的手指", "蜷缩在主人怀里", "用爪子勾住主人的衣角",
                "把脸埋在主人颈窝", "用尾巴扫过主人的手臂", "轻轻舔主人的手心",
                "在主人腿上不安分地扭动", "用脸颊蹭主人的掌心", "小爪子抓住主人的衣服",
                "把玩主人的手指", "用湿润的眼神看着主人", "轻轻拉扯主人的衣角",
                "把尾巴卷在主人手臂上", "用头顶蹭主人的下巴", "慵懒地伸展身体"
            ],
            "调皮": [
                "甩动尾巴", "竖起耳朵", "歪着头", "打滚撒欢",
                "突然窜到主人背后", "从桌子上推下东西", "在主人脚边绕圈圈",
                "假装看不见主人", "突然跳到主人身上", "咬住主人的衣角不放",
                "把主人的东西藏起来", "在主人的书上打滚", "抢走主人的笔",
                "把纸巾抓得到处都是", "追着自己的尾巴转圈", "在主人的键盘上乱按",
                "把主人的袜子叼走", "在主人的床上打滚", "把主人的鞋子藏起来",
                "突然从柜子上跳下来", "在主人工作时要坐键盘", "把主人的头发咬住"
            ],
            "吃醋": ["鼓起脸颊", "背对着主人", "甩尾巴", "叉腰生气"]
        }

        # The sample-building loop lives in its own method so that the style
        # helpers defined below stay at class level.
        return self._build_dataset(novel_files, max_samples)

    def _generate_response(self, text, mood, template):
        """Compose a richer reply: action beat + emotion word + text chunk."""
        action = random.choice(self.action_patterns[mood])
        emotion = random.choice(self.emotion_words[mood])

        response = template['response'].format(
            text_chunk=f"【{action}】{emotion},{text}"
        )
        return response

    def _process_text_style(self, text, mood):
        """Rewrite a text chunk sentence by sentence in the given mood."""
        sentences = text.split("。")
        processed_sentences = []

        for sentence in sentences:
            if not sentence.strip():
                continue

            # Occasionally prefix a bracketed action beat.
            if random.random() < 0.3:
                action = random.choice(self.action_patterns[mood])
                sentence = f"【{action}】{sentence}"

            # Occasionally prefix an emotion word.
            if random.random() < 0.4:
                emotion = random.choice(self.emotion_words[mood])
                sentence = f"{emotion},{sentence}"

            sentence = self._add_emotion_particles(sentence, mood)
            sentence = self._add_ending(sentence, mood)

            processed_sentences.append(sentence)

        return "。".join(processed_sentences)

    def _add_emotion_particles(self, text, mood):
        """Prepend one to three mood-specific interjections."""
        particles = {
            "撒娇": ["呜", "唔", "呜呜", "哼", "啾", "咪"],
            "害羞": ["那个", "这个", "那什么", "那啥", "唔", "呜"],
            "粘人": ["诶嘿", "嘿嘿", "喵喵", "哼哼", "咪咪", "呼呼"],
            "暗示": [
                "啊", "嗯", "唔", "哈", "呜", "嘤",
                "呼", "哈啊", "呜呜", "嗯啊", "唔嗯", "啊呜"
            ],
            "调皮": [
                "嘿", "哈", "噫", "哦", "啦", "呀",
                "嘻嘻", "哼哼", "嘿嘿", "啾啾", "噜噜", "哇哦"
            ],
            "吃醋": ["哼", "切", "啧", "呵", "嗯", "哦"]
        }

        count = random.randint(1, 3)
        selected_particles = random.sample(particles[mood], count)
        return "".join(selected_particles) + "..." + text

    def _add_ending(self, text, mood):
        """Append a mood-specific catgirl ending unless one is already present."""
        endings = {
            "撒娇": ["喵~", "喵喵~", "nya~", "喵呜~", "喵...♡", "喵喵喵~"],
            "害羞": ["喵....", "呜喵~", "...喵", "喵...?", "喵喵....", "...喵呜"],
            "粘人": ["喵喵喵~", "喵~♪", "喵呜~", "喵~❤", "喵喵~", "喵..."],
            "暗示": [
                "喵...♡", "...喵~", "呜喵...", "喵...❤", "喵~", "...喵喵",
                "喵...♥", "...嗯喵", "喵呜...♡", "哈喵....", "喵~...♥", "呼喵..."
            ],
            "调皮": [
                "喵!", "喵喵!", "喵哈~", "喵嘿~", "喵喵喵!", "喵~",
                "喵嘻!", "喵哼~", "喵呜!", "喵嘿嘿~", "喵哇!", "喵嘻嘻~"
            ],
            "吃醋": ["哼喵!", "喵...", "切喵~", "喵!!", "...喵", "喵喵..."]
        }

        if not any(text.endswith(end) for end in endings[mood]):
            text += random.choice(endings[mood])

        return text

    def _build_dataset(self, novel_files, max_samples):
        """Iterate over the novels and emit chat-formatted samples."""
        dataset = []
        base_system_prompt = self.system_prompts["base_prompt"]
        sample_count = 0

        for file in novel_files:
            if sample_count >= max_samples:
                break

            with open(file, 'r', encoding='utf-8') as f:
                text = f.read()
            chunks = self._split_text(text, max_length=256)

            for chunk in chunks:
                if sample_count >= max_samples:
                    break

                for mood, templates in self.dialogue_contexts.items():
                    if sample_count >= max_samples:
                        break

                    processed_chunk = self._process_text_style(
                        chunk,
                        mood=mood
                    )

                    template = random.choice(templates)

                    conversation = f"""<|system|>{base_system_prompt}
当前情境:{mood}</|system|>
<|user|>{template['question']}</|user|>
<|assistant|>{template['response'].format(text_chunk=processed_chunk)}</|assistant|>"""

                    dataset.append({"text": conversation})
                    sample_count += 1

        # dataset is a list of {"text": ...} records, so from_list is the
        # right constructor (from_dict would nest whole dicts in one column).
        return Dataset.from_list(dataset)

    def _split_text(self, text, max_length=256):
        """Split text into chunks without breaking sentences apart."""
        # Capturing the delimiters keeps sentence-ending punctuation as
        # separate list items, so it can be re-attached to its chunk.
        sentences = re.split('([。!?~])', text)
        chunks = []
        current_chunk = []
        current_length = 0

        for sentence in sentences:
            if not sentence.strip():
                continue

            if current_length + len(sentence) > max_length:
                if current_chunk:
                    chunks.append(''.join(current_chunk))
                current_chunk = []
                current_length = 0

            current_chunk.append(sentence)
            current_length += len(sentence)

            # Flush at a sentence boundary once the chunk is half full, so
            # chunks end on punctuation rather than mid-sentence.
            if sentence in ['。', '!', '?', '~'] and current_length > max_length / 2:
                chunks.append(''.join(current_chunk))
                current_chunk = []
                current_length = 0

        if current_chunk:
            chunks.append(''.join(current_chunk))

        return chunks

    def _create_style_response(self, style_text, base_response):
        """Rewrite the base reply to match the wording and cadence of the
        style text. Currently a pass-through placeholder."""
        return base_response

    def train(self, dataset, output_dir="./results"):
        # Trainer expects token ids, not raw strings, so tokenize the
        # "text" column first and drop it afterwards.
        def tokenize(batch):
            return self.tokenizer(
                batch["text"],
                truncation=True,
                max_length=512
            )

        tokenized_dataset = dataset.map(
            tokenize,
            batched=True,
            remove_columns=["text"]
        )

        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=1,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=8,
            save_steps=50,
            logging_steps=10,
            learning_rate=1e-4,
            fp16=False,
            optim="adamw_torch"
        )

        # mlm=False selects the causal-LM collator, which pads each batch
        # and copies input_ids to labels.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=data_collator,
        )

        trainer.train()

        # Saves the LoRA adapter weights plus the tokenizer.
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
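

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original script): the model
# id, prompt file, and novel paths below are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    trainer = ModelTrainer(
        model_id="Qwen/Qwen-1_8B-Chat",             # hypothetical base model
        system_prompts_path="system_prompts.json"   # must contain "base_prompt"
    )
    dataset = trainer.prepare_dataset(
        novel_files=["novels/sample.txt"],          # hypothetical input files
        max_samples=100
    )
    trainer.train(dataset, output_dir="./results")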