# reposhiled-7b / app.py
import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from peft import PeftModel
import gradio as gr
from threading import Thread
# Read the Hugging Face token and model identifiers from the environment
HF_TOKEN = os.environ.get("HF_TOKEN", None)
BASE_MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct"  # base model
LORA_MODEL_PATH = "QLWD/test-7b"  # LoRA adapter repo
# UI title and description
TITLE = "<h1><center>Vulnerability Detection: Fine-Tuned Model Demo</center></h1>"
DESCRIPTION = f"""
<h3>Model: <a href="https://huggingface.co/{LORA_MODEL_PATH}">vulnerability-detection fine-tuned model</a></h3>
<center>
<p>Test the output of the base model combined with the LoRA adapter.</p>
</center>
"""
PLACEHOLDER = """
<center>
<p>Paste the code you want analyzed...</p>
</center>
"""
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
"""
device = "cuda"
# 加载tokenizer和基础模型
tokenizer = AutoTokenizer.from_pretrained(
BASE_MODEL_ID,
use_fast=False,
trust_remote_code=True,
force_download=True
)
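
# Defensive fallback (an addition, not in the original Space): some checkpoints
# ship without an explicit pad token; Qwen2.5 tokenizers normally define one,
# in which case this is a no-op.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token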
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
ignore_mismatched_sizes=True,
force_download=True
)
# Apply the LoRA fine-tuned adapter on top of the base model
model = PeftModel.from_pretrained(
base_model,
LORA_MODEL_PATH,
torch_dtype=torch.bfloat16,
device_map="auto"
)
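
# Optional (a sketch, not part of the original app): if the adapter never needs
# to be detached, merging the LoRA weights into the base model removes the PEFT
# forward-pass indirection. Assumes the adapter was trained against exactly
# this base checkpoint.
#
#     model = model.merge_and_unload()  # returns a plain transformers model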
def format_chat(system_prompt, history, message):
    """Build a Qwen ChatML prompt from the system prompt, prior turns, and the new message."""
    formatted_chat = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
for prompt, answer in history:
formatted_chat += f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n{answer}<|im_end|>\n"
formatted_chat += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
return formatted_chat
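
# For reference, format_chat("You are a code auditor.", [], "int main() {}")
# produces the ChatML prompt:
#
#     <|im_start|>system
#     You are a code auditor.<|im_end|>
#     <|im_start|>user
#     int main() {}<|im_end|>
#     <|im_start|>assistant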
@spaces.GPU()
def stream_chat(
message: str,
history: list,
system_prompt: str,
temperature: float = 0.3,
max_new_tokens: int = 256,
top_p: float = 1.0,
top_k: int = 20,
repetition_penalty: float = 1.2,
):
print(f'message: {message}')
print(f'history: {history}')
formatted_prompt = format_chat(system_prompt, history, message)
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
streamer = TextIteratorStreamer(tokenizer, timeout=5000.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,  # pass the mask explicitly to avoid a generation warning
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id
    )
    # model.generate already disables gradient tracking internally, and a
    # torch.no_grad() context would not propagate into the worker thread anyway,
    # so generation is simply launched in the background and tokens are
    # consumed from the streamer as they arrive.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Stop early if the model emits the end-of-text marker
        if "<|endoftext|>" in buffer:
            yield buffer.split("<|endoftext|>")[0]
            break
        yield buffer
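
# Quick local sanity check (hypothetical; on a ZeroGPU Space the @spaces.GPU
# decorator manages device allocation, so this is only meaningful when running
# locally with a GPU):
#
#     for partial in stream_chat("int main() { char b[4]; gets(b); }", [],
#                                SYSTEM_PROMPT, temperature=0.1,
#                                max_new_tokens=512):
#         print(partial)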
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
SYSTEM_PROMPT = '''You are an expert in code auditing and vulnerability remediation. Carefully analyze the code provided below, scan it, and report every vulnerability and potential risk. Separate the findings with the delimiter "--------" and left-align the report.
List the vulnerabilities and risks from highest to lowest severity, each in the following format:
- **Type**: the vulnerability's type or name (if it has an established name), or the category of potential risk (e.g. resource leak, boundary-condition issue).
- **Severity**: a rating based on the seriousness of the vulnerability or risk (e.g. high, medium, low).
- **Description**: a professional, detailed explanation of the technical details and root cause of the vulnerability, or a description of the potential risk.
- **Impact**: the concrete impact the vulnerability or risk could have on the system, data, or users.
- **Remediation advice**: concrete steps or recommendations for fixing the vulnerability or risk (the approach to the fix rather than only fixed code).
- **Vulnerable code segment**: the exact location and code segment where the vulnerability occurs (if applicable).
- **Fixed code segment**: a replacement code segment that fixes the vulnerability, ready for developers to use (if applicable).
Make sure to scan for and **report all** vulnerabilities and risks that you are certain of or that are highly likely to exist.
Use the delimiter "--------" between findings.'''
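
# Since the report uses "--------" as a delimiter, downstream consumers could
# split a finished response into individual findings, e.g. (hypothetical
# helper, not used by this app):
#
#     findings = [f.strip() for f in report.split("--------") if f.strip()]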
with gr.Blocks(css=CSS, theme="soft") as demo:
gr.HTML(TITLE)
gr.HTML(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate this Space for private deployment", elem_classes="duplicate-button")
gr.ChatInterface(
fn=stream_chat,
chatbot=chatbot,
fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
additional_inputs=[
gr.Textbox(
value=SYSTEM_PROMPT,
label="系统提示词",
render=False,
),
gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.1,
label="Temperature",
render=False,
),
gr.Slider(
minimum=128,
maximum=8192,
step=1,
value=8192,
label="最大生成长度",
render=False,
),
gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.1,
value=1.0,
label="Top-p",
render=False,
),
gr.Slider(
minimum=1,
maximum=50,
step=1,
value=20,
label="Top-k",
render=False,
),
gr.Slider(
minimum=0.0,
maximum=2.0,
step=0.1,
value=1.2,
label="重复惩罚",
render=False,
),
],
examples=None,
cache_examples=False,
)
if __name__ == "__main__":
demo.launch(share=True)