# codebertBase / app.py
# Forrest99's picture
# Update app.py
# d37c72d verified
from fastapi import FastAPI, Body
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import os
import logging
# === Application setup ===
app = FastAPI(title="Code Security API")

# Work around cross-origin restrictions: accept any origin, method and header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# === Force the Hugging Face cache location ===
# The same path is exported as HF_HOME and passed explicitly as cache_dir
# when the model and tokenizer are loaded further down.
cache_path = "/app/.cache/huggingface"
os.environ["HF_HOME"] = cache_path
os.makedirs(cache_path, exist_ok=True)

# === Logging setup ===
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("CodeBERT-API")
# === Root route (must be defined) ===
@app.get("/")
async def read_root():
    """Health-check endpoint: report the service status and its routes."""
    available_routes = {
        "detect": "POST /detect - 代码安全检测",
        "specs": "GET /openapi.json - API文档",
    }
    return {"status": "running", "endpoints": available_routes}
# === Model loading ===
# Single hub identifier for both loads, so the model and its tokenizer
# cannot drift apart if the checkpoint is ever changed.
MODEL_NAME = "mrm8488/codebert-base-finetuned-detect-insecure-code"

try:
    logger.info("Loading model from: %s", cache_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        cache_dir=cache_path,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        cache_dir=cache_path,
    )
    logger.info("Model loaded successfully")
except Exception as e:
    # logger.exception records the full traceback, not just the message.
    logger.exception("Model load failed: %s", e)
    # Abort startup; chain the original error as the explicit cause so the
    # traceback shows why initialization failed.
    raise RuntimeError("模型初始化失败") from e
# === Core detection endpoint ===
@app.post("/detect")
async def detect_vulnerability(payload: dict = Body(...)):
    """Classify a code snippet with the loaded sequence-classification model.

    Args:
        payload: JSON request body; the snippet is read from its ``"code"`` key.

    Returns:
        On success, ``{"label": int, "confidence": float}`` where ``label`` is
        the index of the highest-scoring class and ``confidence`` is that
        class's softmax probability rounded to 4 decimals.
        Otherwise ``{"error": str, "tip": str}``: returned when ``"code"`` is
        missing, blank, or not a string, and when inference raises.
    """
    try:
        raw_code = payload.get("code", "")
        # A JSON null/number/list would fail on .strip() and come back as a
        # confusing AttributeError message; treat it like a missing snippet.
        code = raw_code.strip() if isinstance(raw_code, str) else ""
        if not code:
            return {"error": "代码内容为空", "tip": "请提供有效的代码字符串"}

        # Cap over-long input at 2000 characters before tokenizing.
        code = code[:2000]

        # Model inference.
        inputs = tokenizer(
            code,
            return_tensors="pt",
            truncation=True,
            padding=True,  # let the tokenizer pick the padding strategy
            max_length=512,
        )
        with torch.no_grad():
            outputs = model(**inputs)

        # Exactly one snippet is tokenized per request, so read the single
        # batch row explicitly instead of taking argmax over the flat tensor.
        logits = outputs.logits
        scores = logits[0]
        label_id = scores.argmax(dim=-1).item()
        confidence = scores.softmax(dim=-1)[label_id].item()
        logger.info(
            "Code analyzed. Logits: %s, Prediction: %s, Confidence: %.4f",
            logits.tolist(),
            label_id,
            confidence,
        )
        return {
            "label": label_id,  # 0: safe, 1: insecure
            "confidence": round(confidence, 4),
        }
    except Exception as e:
        # Endpoint boundary: log the traceback and report the failure in the
        # response body rather than letting the exception propagate.
        logger.exception("Error during model inference: %s", e)
        return {
            "error": str(e),
            "tip": "请检查输入代码是否包含非ASCII字符或格式错误",
        }