# Spaces: Sleeping
import logging
import os

import torch
from fastapi import Body, FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# === Application setup ===
app = FastAPI(title="Code Security API")

# Cross-origin access: accept requests from any origin, with any HTTP method
# and any request header, so browser clients hosted elsewhere can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# === Hugging Face cache location ===
# Point HF_HOME at a fixed directory (overriding any inherited value) and make
# sure that directory exists before the model is downloaded into it.
# NOTE(review): this runs after `transformers` has been imported; code that
# reads HF_HOME only at import time may not see it — confirm whether the
# explicit cache_dir passed at load time is what actually takes effect.
cache_path = os.environ["HF_HOME"] = "/app/.cache/huggingface"
os.makedirs(cache_path, exist_ok=True)
# === Logging ===
# Configure root logging at INFO and create the service's named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger("CodeBERT-API")
# === Root route (health check; must be registered) ===
# The decorator is what registers this handler with `app`; without it GET /
# is never routed to this function.
@app.get("/")
async def read_root():
    """Health-check endpoint.

    Returns a static payload reporting that the service is running and
    listing the endpoints it exposes.
    """
    return {
        "status": "running",
        "endpoints": {
            "detect": "POST /detect - 代码安全检测",
            "specs": "GET /openapi.json - API文档",
        },
    }
# === Model loading ===
# Model id passed to both from_pretrained calls below (kept in one place so the
# model and its tokenizer cannot drift apart).
MODEL_NAME = "mrm8488/codebert-base-finetuned-detect-insecure-code"

try:
    logger.info("Loading model from: %s", cache_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        cache_dir=cache_path,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        cache_dir=cache_path,
    )
except Exception as e:
    # logger.exception records the full traceback, not just the message.
    logger.exception("Model load failed: %s", e)
    # Chain the original error so the root cause stays visible; the service
    # refuses to start without a model.
    raise RuntimeError("模型初始化失败") from e
else:
    logger.info("Model loaded successfully")
# === Core detection endpoint ===
# Maximum number of characters of submitted code kept before tokenization.
MAX_CODE_CHARS = 2000
# Maximum number of tokens fed to the model (tokenizer truncates beyond this).
MAX_TOKENS = 512


def _run_model(inputs):
    """Run one gradient-free forward pass of ``model`` on tokenized *inputs*."""
    # torch's grad mode is thread-local, so no_grad must be entered inside the
    # worker thread that actually executes the forward pass.
    with torch.no_grad():
        return model(**inputs)


@app.post("/detect")
async def detect_vulnerability(payload: dict = Body(...)):
    """Classify a code snippet with the loaded sequence-classification model.

    Args:
        payload: JSON body; the snippet is read from its ``"code"`` field.

    Returns:
        On success, ``{"label": int, "confidence": float}`` where label ``0``
        means secure and ``1`` insecure, and confidence is the softmax
        probability of the predicted label rounded to 4 decimals.
        On invalid input or inference failure, ``{"error": str, "tip": str}``;
        errors are returned in the response body rather than raised.
    """
    code = payload.get("code", "")
    # A non-string value (e.g. null or a number) would otherwise crash on
    # .strip() and be reported as a misleading inference error.
    if not isinstance(code, str):
        return {"error": "code 字段必须是字符串", "tip": "请提供有效的代码字符串"}
    code = code.strip()
    if not code:
        return {"error": "代码内容为空", "tip": "请提供有效的代码字符串"}
    # Cap raw input size before tokenizing; the tokenizer also truncates.
    code = code[:MAX_CODE_CHARS]

    try:
        inputs = tokenizer(
            code,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=MAX_TOKENS,
        )
        # The forward pass is blocking compute; run it in a worker thread so
        # it does not stall the event loop. Tokenization stays on the
        # event-loop thread, so the shared tokenizer is never used concurrently.
        outputs = await run_in_threadpool(_run_model, inputs)
        logits = outputs.logits
        # A single sequence was tokenized, so row 0 holds its class scores.
        row = logits[0]
        label_id = row.argmax().item()
        confidence = row.softmax(dim=-1)[label_id].item()
    except Exception as e:
        # Request boundary: keep the traceback in the logs, report to client.
        logger.exception("Error during model inference: %s", e)
        return {
            "error": str(e),
            "tip": "请检查输入代码是否包含非ASCII字符或格式错误",
        }

    logger.info(
        "Code analyzed. Logits: %s, Prediction: %s, Confidence: %.4f",
        logits.tolist(),
        label_id,
        confidence,
    )
    return {
        "label": label_id,  # 0: secure, 1: insecure
        "confidence": round(confidence, 4),
    }