nekoa / app.py
jljiu's picture
Upload 4 files
b95f55d verified
raw
history blame
3.24 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import os
from train import ModelTrainer
class NovelAIApp:
    """Gradio front-end for fine-tuning and chatting with a stylised language model.

    The instance keeps three lazily populated pieces of state:

    * ``model`` / ``tokenizer`` -- set by :meth:`load_model`.
    * ``trainer`` -- a ``ModelTrainer`` created on the first :meth:`train_model` call.

    System prompts are read once, at construction time, from
    ``configs/system_prompts.json``.
    """

    def __init__(self):
        """Initialise empty model state and load the system prompts from disk.

        Raises:
            OSError: if ``configs/system_prompts.json`` cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        self.model = None
        self.tokenizer = None
        self.trainer = None
        # Load the system prompts.
        with open('configs/system_prompts.json', 'r', encoding='utf-8') as f:
            self.system_prompts = json.load(f)

    def load_model(self, model_path):
        """Load the tokenizer and an 8-bit causal LM from ``model_path``.

        Args:
            model_path: Local directory or hub identifier accepted by
                ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # NOTE(review): recent transformers releases prefer
        # ``quantization_config=BitsAndBytesConfig(load_in_8bit=True)`` over the
        # bare ``load_in_8bit`` flag -- confirm against the pinned version.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            load_in_8bit=True,
            device_map="auto",
        )

    def train_model(self, files):
        """Fine-tune on the uploaded reference files and return a status message.

        Args:
            files: The value delivered by the ``gr.File`` upload component
                (``None`` or an empty list when nothing was uploaded).

        Returns:
            A user-facing status string for the "training status" textbox.
        """
        # Nothing uploaded: report it instead of handing an empty value to the
        # trainer (and avoid constructing the trainer for no reason).
        if not files:
            return "请先上传参考文本文件!"

        if self.trainer is None:
            self.trainer = ModelTrainer(
                "CohereForAI/c4ai-command-r-plus-08-2024",
                "configs/system_prompts.json",
            )
        dataset = self.trainer.prepare_dataset(files)
        self.trainer.train(dataset)
        return "训练完成!"

    def generate_text(self, prompt, system_prompt_type="creative"):
        """Generate a reply to ``prompt`` using the selected system prompt.

        Args:
            prompt: The user's message.
            system_prompt_type: Key into the loaded system prompts; when the key
                is absent the ``"base_prompt"`` entry is used instead.

        Returns:
            The newly generated text only (the prompt is not echoed back), or a
            notice asking the user to load a model first.

        Raises:
            KeyError: if neither ``system_prompt_type`` nor ``"base_prompt"``
                exists in the loaded system prompts.
        """
        if self.model is None or self.tokenizer is None:
            return "请先加载模型!"

        # Resolve the fallback lazily so a missing "base_prompt" entry only
        # matters when the requested style is itself missing.
        system_prompt = self.system_prompts.get(system_prompt_type)
        if system_prompt is None:
            system_prompt = self.system_prompts["base_prompt"]

        formatted_prompt = (
            f"<|system|>{system_prompt}</|system|>\n"
            f"<|user|>{prompt}</|user|>\n"
            "<|assistant|>"
        )

        inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
        # The model is loaded with device_map="auto", so the inputs must be
        # moved to wherever the model actually lives.
        device = self.model.device
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        outputs = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            # Budget applies to the reply only; ``max_length`` would also count
            # the (potentially long) system prompt.
            max_new_tokens=512,
            # temperature / top_p are ignored unless sampling is enabled.
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            num_return_sequences=1,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # generate() returns prompt + continuation; keep only the continuation
        # so the chat window does not repeat the system/user text.
        new_tokens = outputs[0][input_ids.shape[-1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)

    def _chat_response(self, message, history, system_prompt_type="creative"):
        """Adapt ``gr.ChatInterface``'s ``(message, history, *extras)`` call to
        :meth:`generate_text`; the chat history is currently not used."""
        return self.generate_text(message, system_prompt_type)

    def create_interface(self):
        """Build and return the Gradio ``Blocks`` UI (training tab + chat tab)."""
        with gr.Blocks() as interface:
            gr.Markdown("# 风格化对话助手")

            with gr.Tab("模型训练"):
                gr.Markdown(
                    "### 上传参考文本\n"
                    "上传文本文件来训练模型学习特定的语言风格。\n"
                    "建议上传具有鲜明语言特色的文本。"
                )
                file_output = gr.File(
                    file_count="multiple",
                    label="上传参考文本文件",
                )
                train_button = gr.Button("开始训练")
                train_output = gr.Textbox(label="训练状态")
                # Without this binding the button does nothing.
                train_button.click(
                    fn=self.train_model,
                    inputs=[file_output],
                    outputs=[train_output],
                )

            with gr.Tab("对话"):
                gr.Markdown("与助手进行对话,体验风格化的语言表达")
                style_select = gr.Dropdown(
                    choices=["formal", "casual"],
                    label="选择对话风格",
                    value="formal",
                )
                # NOTE(review): nothing in this UI calls load_model(), so the
                # chat will answer with the "load a model first" notice until a
                # model is loaded elsewhere -- confirm the intended checkpoint.
                gr.ChatInterface(
                    fn=self._chat_response,
                    additional_inputs=[style_select],
                )
        return interface
# Create the application instance and build its UI; both stay available as
# module-level names.
app = NovelAIApp()
interface = app.create_interface()

# Start the (blocking) web server only when this file is executed as a script,
# so importing the module does not launch it as a side effect.
if __name__ == "__main__":
    interface.launch()