import json
import os

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

from train import ModelTrainer

# Base model used both for fine-tuning and as the default path offered in the UI.
DEFAULT_MODEL_PATH = "THUDM/chatglm2-6b"


class NovelAIApp:
    """Gradio front end for fine-tuning a chat model and chatting with it."""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.trainer = None
        # Load the system prompts (read relative to the current working directory).
        with open('configs/system_prompts.json', 'r', encoding='utf-8') as f:
            self.system_prompts = json.load(f)
        # Initialise the default mood/scenario.
        self.current_mood = "暗示"

    def load_model(self, model_path):
        """Load the tokenizer and an 8-bit model from *model_path*.

        Both objects are loaded before either attribute is assigned, so a
        failure while loading the model does not leave a half-initialised app
        (tokenizer set, model still ``None``). Exceptions raised by
        ``from_pretrained`` propagate to the caller.

        Returns:
            A human-readable status string (usable directly as a Gradio output).
        """
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            load_in_8bit=True,
            device_map="auto",
        )
        self.tokenizer = tokenizer
        self.model = model
        return f"模型加载完成: {model_path}"

    def _load_model_from_ui(self, model_path):
        """Gradio callback: load the model and report success or failure as text."""
        model_path = (model_path or "").strip()
        if not model_path:
            return "请输入模型路径!"
        try:
            return self.load_model(model_path)
        except Exception as err:  # UI boundary: show the error instead of failing silently
            return f"模型加载失败: {err}"

    def train_model(self, files):
        """Fine-tune the base model on the uploaded novel text files.

        Args:
            files: Value of the multi-file ``gr.File`` component; ``None`` or
                empty when nothing was uploaded.

        Returns:
            A status message for the training-status textbox.
        """
        if not files:
            return "请先上传小说文本文件!"
        if not self.trainer:
            self.trainer = ModelTrainer(
                DEFAULT_MODEL_PATH,
                "configs/system_prompts.json",
            )
        dataset = self.trainer.prepare_dataset(files)
        self.trainer.train(dataset)
        return "训练完成!"

    @staticmethod
    def _format_history(history):
        """Render prior chat turns as ``<|user|>`` / ``<|assistant|>`` prompt text.

        Accepts both Gradio history formats:
        ``type="messages"`` (a list of ``{"role": ..., "content": ...}`` dicts)
        and the legacy ``type="tuples"`` (a list of ``[user, assistant]`` pairs).
        Non-text message contents (e.g. files) and ``None`` entries are skipped.
        """
        parts = []
        for item in history or []:
            if isinstance(item, dict):
                role = item.get("role")
                content = item.get("content")
                if role in ("user", "assistant") and isinstance(content, str):
                    parts.append(f"<|{role}|>{content}\n")
            else:
                user_text, bot_text = item[0], item[1]
                if user_text is not None:
                    parts.append(f"<|user|>{user_text}\n")
                if bot_text is not None:
                    parts.append(f"<|assistant|>{bot_text}\n")
        return "".join(parts)

    def generate_text(self, message, history):
        """Generate the assistant's reply; signature matches ``gr.ChatInterface``.

        Args:
            message: The user's latest message text.
            history: Prior turns as supplied by Gradio (see ``_format_history``).

        Returns:
            The generated reply, or a hint to load the model first.
        """
        if self.model is None or self.tokenizer is None:
            return "请先加载模型!"

        # Default to "" so a missing key does not put the literal text "None" in the prompt.
        system_prompt = self.system_prompts.get("base_prompt", "")
        formatted_prompt = (
            f"<|system|>{system_prompt}\n"
            f"{self._format_history(history)}"
            f"<|user|>{message}\n"
            "<|assistant|>"
        )

        # Move the inputs to wherever device_map="auto" placed the model.
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
        prompt_length = inputs["input_ids"].shape[-1]
        outputs = self.model.generate(
            inputs["input_ids"],
            # max_new_tokens (unlike max_length) does not shrink as the history grows.
            max_new_tokens=512,
            # Without do_sample=True, temperature/top_p are ignored (greedy decoding).
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
        )
        # Decode only the newly generated tokens. Splitting the full decoded text
        # on "<|assistant|>" fails if those markers are stripped as special tokens.
        response = self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
        # Defensive: drop any role marker the model itself emitted.
        return response.split("<|assistant|>")[-1].strip()

    def create_interface(self):
        """Build the Gradio UI: a training tab and a chat tab with model loading."""
        with gr.Blocks() as interface:
            gr.Markdown("# 猫娘对话助手")

            with gr.Tab("模型训练"):
                file_output = gr.File(
                    file_count="multiple",
                    label="上传小说文本文件",
                )
                train_button = gr.Button("开始训练")
                train_output = gr.Textbox(label="训练状态")

                train_button.click(
                    fn=self.train_model,
                    inputs=[file_output],
                    outputs=[train_output],
                )

            with gr.Tab("对话"):
                # Without this control load_model() was unreachable, so every
                # chat turn answered "请先加载模型!".
                with gr.Row():
                    model_path_box = gr.Textbox(
                        value=DEFAULT_MODEL_PATH,
                        label="模型路径",
                    )
                    load_button = gr.Button("加载模型")
                load_status = gr.Textbox(label="模型状态", interactive=False)
                load_button.click(
                    fn=self._load_model_from_ui,
                    inputs=[model_path_box],
                    outputs=[load_status],
                )

                gr.ChatInterface(
                    fn=self.generate_text,
                    title="与猫娘对话",
                    description="来和可爱的猫娘聊天吧~",
                    theme="soft",
                    examples=["今天天气真好呢", "你在做什么呢?", "要不要一起玩?"],
                    cache_examples=False,
                    type="messages",
                )

        return interface


# Create the application instance (kept as module attributes for importers).
app = NovelAIApp()
interface = app.create_interface()

if __name__ == "__main__":
    # NOTE(review): 0.0.0.0 plus share=True exposes the app publicly with no auth.
    interface.launch(
        server_name="0.0.0.0",  # listen on all interfaces (allow external access)
        share=True,  # also create a public share link
        ssl_verify=False,  # skip certificate validation (meant for self-signed HTTPS certs)
    )