"""HTTP API that classifies a code snippet as secure or insecure.

The service wraps the ``mrm8488/codebert-base-finetuned-detect-insecure-code``
sequence-classification model behind a small FastAPI application:

* ``GET /``        - health check / endpoint listing
* ``POST /detect`` - classify the ``code`` field of a JSON body

The model and tokenizer are loaded once, at import time.
"""

import logging
import os

import torch
from fastapi import Body, FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# === Application setup ===
app = FastAPI(title="Code Security API")

# Allow cross-origin requests from any origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# === Force the Hugging Face cache location ===
# NOTE(review): the Hugging Face libraries are already imported above and
# appear to read HF_HOME at import time, so this assignment probably does not
# change their default cache; the explicit ``cache_dir`` passed to
# ``from_pretrained`` below is what directs the downloads - confirm.
os.environ["HF_HOME"] = "/app/.cache/huggingface"
cache_path = os.getenv("HF_HOME")
os.makedirs(cache_path, exist_ok=True)

# === Logging ===
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("CodeBERT-API")

# === Tunables ===
MODEL_NAME = "mrm8488/codebert-base-finetuned-detect-insecure-code"
MAX_CODE_CHARS = 2000  # characters kept from the submitted code
MAX_TOKENS = 512  # token limit passed to the tokenizer


# === Root route ===
@app.get("/")
async def read_root():
    """Health-check endpoint listing the available routes."""
    return {
        "status": "running",
        "endpoints": {
            "detect": "POST /detect - 代码安全检测",
            "specs": "GET /openapi.json - API文档",
        },
    }


# === Model loading ===
try:
    logger.info("Loading model from: %s", cache_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        cache_dir=cache_path,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        cache_dir=cache_path,
    )
    logger.info("Model loaded successfully")
except Exception as e:
    logger.error("Model load failed: %s", e)
    # Chain the original exception so the root cause stays in the traceback.
    raise RuntimeError("模型初始化失败") from e


def _classify(code):
    """Run the model on *code* and return ``(label_id, confidence, logits)``.

    This is a synchronous, CPU-bound helper; the endpoint runs it in a worker
    thread so that inference does not block the event loop.

    Args:
        code: Source text to classify (already length-limited by the caller).

    Returns:
        A tuple of the predicted class index, the softmax probability of that
        class, and the raw logits as a nested list (for logging).
    """
    inputs = tokenizer(
        code,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_TOKENS,
    )
    with torch.no_grad():
        logits = model(**inputs).logits

    # Exactly one snippet is submitted, so work on row 0 explicitly rather
    # than taking argmax over the flattened (batch, classes) tensor.
    row = logits[0]
    label_id = int(row.argmax().item())
    confidence = row.softmax(dim=-1)[label_id].item()
    return label_id, confidence, logits.tolist()


# === Detection endpoint ===
@app.post("/detect")
async def detect_vulnerability(payload: dict = Body(...)):
    """Classify the submitted code snippet.

    Args:
        payload: JSON object; its ``code`` entry holds the source text.

    Returns:
        ``{"label": int, "confidence": float}`` on success (the original
        author's note: 0 = secure, 1 = insecure), or ``{"error": ..., "tip":
        ...}`` when the input is missing/blank/not a string or inference
        fails.
    """
    code = payload.get("code", "")
    # A null or non-string "code" used to raise AttributeError on .strip()
    # and surface as a misleading inference error; treat it as missing input.
    if not isinstance(code, str) or not code.strip():
        return {"error": "代码内容为空", "tip": "请提供有效的代码字符串"}

    # Strip first, then cap the length of overly long input.
    code = code.strip()[:MAX_CODE_CHARS]

    try:
        # Inference is blocking; run it off the event loop.
        label_id, confidence, raw_logits = await run_in_threadpool(
            _classify, code
        )
    except Exception as e:
        # logger.exception keeps the traceback in the server log.
        logger.exception("Error during model inference: %s", e)
        return {
            "error": str(e),
            "tip": "请检查输入代码是否包含非ASCII字符或格式错误",
        }

    logger.info(
        "Code analyzed. Logits: %s, Prediction: %d, Confidence: %.4f",
        raw_logits,
        label_id,
        confidence,
    )
    return {
        "label": label_id,  # 0: secure, 1: insecure
        "confidence": round(confidence, 4),
    }