import json
import os

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

from train import ModelTrainer


class NovelAIApp:
    """Gradio app for fine-tuning a language model on reference texts and chatting with it."""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.trainer = None

        # Style presets (system prompts) live in a separate config file.
        with open('configs/system_prompts.json', 'r', encoding='utf-8') as f:
            self.system_prompts = json.load(f)
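    # The contents of configs/system_prompts.json are not shown in this file. Based on
    # the keys referenced below ("base_prompt" as the fallback, plus the "formal" and
    # "casual" styles offered in the UI), an assumed minimal layout would be:
    #
    #   {
    #     "base_prompt": "You are a helpful writing assistant.",
    #     "formal": "Respond in a formal, polished register.",
    #     "casual": "Respond in a relaxed, conversational register."
    #   }
    #
    # The prompt strings above are illustrative placeholders, not the project's actual prompts.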
    def load_model(self, model_path):
        """Load a (possibly fine-tuned) model and its tokenizer from model_path."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            load_in_8bit=True,   # 8-bit quantization; requires the bitsandbytes package
            device_map="auto"
        )
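    # Note: passing load_in_8bit directly to from_pretrained is deprecated in recent
    # transformers releases. A minimal sketch of the equivalent call via
    # BitsAndBytesConfig (assuming a transformers version that provides it, with
    # bitsandbytes installed):
    #
    #   from transformers import BitsAndBytesConfig
    #   self.model = AutoModelForCausalLM.from_pretrained(
    #       model_path,
    #       quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    #       device_map="auto",
    #   )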
    def train_model(self, files):
        """Fine-tune the base model on the uploaded reference texts."""
        if not self.trainer:
            self.trainer = ModelTrainer(
                "CohereForAI/c4ai-command-r-plus-08-2024",
                "configs/system_prompts.json"
            )

        dataset = self.trainer.prepare_dataset(files)
        self.trainer.train(dataset)
        return "Training complete!"
    def generate_text(self, prompt, history=None, system_prompt_type="creative"):
        """Generate a reply for the chat UI.

        gr.ChatInterface calls fn(message, history, *additional_inputs), so the
        conversation history is accepted here even though it is not used yet.
        """
        if not self.model:
            return "Please load a model first!"

        # Fall back to the base prompt if the requested style is not configured.
        system_prompt = self.system_prompts.get(system_prompt_type, self.system_prompts["base_prompt"])

        formatted_prompt = f"""<|system|>{system_prompt}</|system|>
<|user|>{prompt}</|user|>
<|assistant|>"""

        # Move the encoded prompt to the model's device (device_map="auto" may place it on GPU).
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(
            inputs["input_ids"],
            max_length=512,
            do_sample=True,          # temperature/top_p only take effect when sampling
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            num_return_sequences=1,
            pad_token_id=self.tokenizer.eos_token_id
        )

        # Decode only the newly generated tokens so the prompt is not echoed back into the chat.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
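    # Note: the <|system|>/<|user|>/<|assistant|> tags above are a hand-rolled template.
    # If the tokenizer ships a chat template, an equivalent sketch (not the project's
    # original method) would be:
    #
    #   messages = [{"role": "system", "content": system_prompt},
    #               {"role": "user", "content": prompt}]
    #   input_ids = self.tokenizer.apply_chat_template(
    #       messages, add_generation_prompt=True, return_tensors="pt"
    #   )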
    def create_interface(self):
        """Build the Gradio UI: one tab for training, one tab for chatting."""
        with gr.Blocks() as interface:
            gr.Markdown("# Stylized Dialogue Assistant")

            with gr.Tab("Model Training"):
                gr.Markdown("""### Upload Reference Texts
Upload text files to teach the model a specific language style.
Texts with a distinctive, recognizable voice work best.""")

                file_output = gr.File(
                    file_count="multiple",
                    label="Upload reference text files"
                )
                train_button = gr.Button("Start Training")
                train_output = gr.Textbox(label="Training Status")

                # Wire the button to the training routine.
                train_button.click(self.train_model, inputs=file_output, outputs=train_output)

            with gr.Tab("Chat"):
                gr.Markdown("Chat with the assistant and try out its stylized phrasing.")
                style_select = gr.Dropdown(
                    choices=["formal", "casual"],
                    label="Dialogue style",
                    value="formal"
                )
                chat_interface = gr.ChatInterface(
                    fn=self.generate_text,
                    additional_inputs=[style_select]
                )

        return interface
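# Note: the interface above never calls load_model(), so the chat tab will keep replying
# "Please load a model first!" until a model is loaded programmatically. A hypothetical
# example (the path is a placeholder, not part of the original project):
#
#   app.load_model("CohereForAI/c4ai-command-r-plus-08-2024")  # or a fine-tuned checkpoint dir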


if __name__ == "__main__":
    app = NovelAIApp()
    interface = app.create_interface()
    interface.launch()