# app.py: Gradio demo Space for larry1129/WooWoof_AI
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Model name (replace with the name of the model you uploaded)
model_name = "larry1129/WooWoof_AI"
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True  # keep this if your model ships custom code
)
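# Note: device_map="auto" requires the accelerate package to be installed
# (e.g. via the Space's requirements.txt), otherwise from_pretrained raises
# an error.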
# Use the EOS token for padding
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id
# Switch to evaluation mode
model.eval()
# Build the instruction-style prompt
def generate_prompt(instruction, input_text=""):
    if input_text:
        prompt = f"""### Instruction:
{instruction}
### Input:
{input_text}
### Response:
"""
    else:
        prompt = f"""### Instruction:
{instruction}
### Response:
"""
    return prompt
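# For a hypothetical instruction-only call such as
#   generate_prompt("Name three dog breeds.")
# the returned prompt is:
#   ### Instruction:
#   Name three dog breeds.
#   ### Response: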
# Generate a response for the given instruction and optional input
def generate_response(instruction, input_text):
    prompt = generate_prompt(instruction, input_text)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=128,
            temperature=0.7,
            top_p=0.95,
            do_sample=True,
        )
    # Decode the full sequence, then keep only the text after the last
    # "### Response:" marker
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = response.split("### Response:")[-1].strip()
    return response
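# Quick sanity check outside the UI (hypothetical example; uncomment to run):
# print(generate_response("Name three dog breeds.", ""))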
# Build the Gradio interface (gr.inputs.Textbox is deprecated; use gr.Textbox)
iface = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter an instruction...", label="Instruction"),
        gr.Textbox(lines=2, placeholder="Add any extra input here...", label="Input (optional)"),
    ],
    outputs="text",
    title="WooWoof AI Interactive Chat",
    description="A large language model based on LLAMA 3.1 that takes an instruction and an optional input.",
    allow_flagging="never"
)
# Launch the Gradio app
iface.launch()