import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import os

from train import ModelTrainer


class NovelAIApp:
    """Gradio app for fine-tuning a causal LM on style reference texts and chatting with it.

    Wires two tabs: a training tab that feeds uploaded files to ModelTrainer,
    and a chat tab backed by `generate_text` with a selectable system-prompt style.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.trainer = None

        # Load system prompts (JSON mapping style name -> prompt text).
        with open('configs/system_prompts.json', 'r', encoding='utf-8') as f:
            self.system_prompts = json.load(f)

    def load_model(self, model_path):
        """Load tokenizer and 8-bit quantized model from `model_path`."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            load_in_8bit=True,
            device_map="auto"
        )

    def train_model(self, files):
        """Fine-tune on the uploaded reference files; returns a status string."""
        if not self.trainer:
            # Lazily construct the trainer so app startup doesn't pull model weights.
            self.trainer = ModelTrainer(
                "CohereForAI/c4ai-command-r-plus-08-2024",
                "configs/system_prompts.json"
            )

        dataset = self.trainer.prepare_dataset(files)
        self.trainer.train(dataset)
        return "训练完成!"

    def generate_text(self, prompt, history=None, system_prompt_type="creative"):
        """Generate a styled reply for the chat tab.

        Args:
            prompt: the user's message.
            history: chat history supplied by gr.ChatInterface (unused here,
                but required — ChatInterface calls fn(message, history, *extras),
                so omitting it made every chat turn raise TypeError).
            system_prompt_type: key into self.system_prompts; falls back to
                "base_prompt" when the key is missing.

        Returns:
            The newly generated text only (prompt tokens are stripped).
        """
        if not self.model:
            return "请先加载模型!"

        system_prompt = self.system_prompts.get(system_prompt_type,
                                              self.system_prompts["base_prompt"])

        formatted_prompt = f"""<|system|>{system_prompt}
<|user|>{prompt}
<|assistant|>"""

        inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
        outputs = self.model.generate(
            # Pass attention_mask along with input_ids so padding is handled correctly.
            **inputs,
            # max_new_tokens bounds only the reply; max_length would include the prompt.
            max_new_tokens=512,
            # Sampling must be enabled, otherwise temperature/top_p are silently ignored.
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            num_return_sequences=1,
            pad_token_id=self.tokenizer.eos_token_id
        )
        # Decode only the tokens generated after the prompt, so the chat UI
        # doesn't echo the system/user template back to the user.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)

    def create_interface(self):
        """Build and return the two-tab Gradio Blocks UI."""
        with gr.Blocks() as interface:
            gr.Markdown("# 风格化对话助手")

            with gr.Tab("模型训练"):
                gr.Markdown("""### 上传参考文本
                上传文本文件来训练模型学习特定的语言风格。
                建议上传具有鲜明语言特色的文本。""")
                file_output = gr.File(
                    file_count="multiple",
                    label="上传参考文本文件"
                )
                train_button = gr.Button("开始训练")
                train_output = gr.Textbox(label="训练状态")

            with gr.Tab("对话"):
                gr.Markdown("与助手进行对话,体验风格化的语言表达")
                style_select = gr.Dropdown(
                    choices=["formal", "casual"],
                    label="选择对话风格",
                    value="formal"
                )
                chat_interface = gr.ChatInterface(
                    # ChatInterface invokes fn(message, history, *additional_inputs);
                    # generate_text's signature matches that contract.
                    fn=self.generate_text,
                    additional_inputs=[style_select]
                )

        return interface


# Guarded entry point: importing this module must not launch a web server.
if __name__ == "__main__":
    app = NovelAIApp()
    interface = app.create_interface()
    interface.launch()