Spaces:
Running
Running
File size: 2,920 Bytes
d37c72d e28e6dd a367787 0a27391 bac242d 338753c 0a27391 338753c 0a27391 fc986a8 e28e6dd fc986a8 eb9892c 0a27391 fc986a8 338753c 0a27391 338753c 0a27391 113ca35 0a27391 bac242d d37c72d 0a27391 d37c72d 0a27391 d37c72d 0a27391 d37c72d 0a27391 d37c72d 0a27391 d37c72d 0a27391 d37c72d 0a27391 d37c72d 0a27391 d37c72d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
from fastapi import FastAPI, Body
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import os
import logging
# === Logging ===
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("CodeBERT-API")

# === Hugging Face cache location ===
# HF_HOME is overwritten unconditionally, so any value supplied by the
# environment is replaced; the directory is created if it does not exist yet.
cache_path = "/app/.cache/huggingface"
os.environ["HF_HOME"] = cache_path
os.makedirs(cache_path, exist_ok=True)

# === Application setup ===
app = FastAPI(title="Code Security API")

# CORS: accept requests from any origin, with any method and any header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# === Root route ===
@app.get("/", description="健康检查端点")
async def read_root():
    """Health-check endpoint: report that the service is up and list its routes.

    The user-facing description published in the OpenAPI schema is supplied
    through the decorator; this docstring is for maintainers only.

    Returns:
        A dict with a fixed ``"status"`` of ``"running"`` and an
        ``"endpoints"`` mapping that describes the available routes.
    """
    available_endpoints = {
        "detect": "POST /detect - 代码安全检测",
        "specs": "GET /openapi.json - API文档",
    }
    return {"status": "running", "endpoints": available_endpoints}
# === Model loading ===
# Hub identifier shared by the classifier and its tokenizer, kept in one place
# so the two can never drift apart.
MODEL_NAME = "mrm8488/codebert-base-finetuned-detect-insecure-code"

try:
    logger.info("Loading model from: %s", cache_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        cache_dir=cache_path,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        cache_dir=cache_path,
    )
except Exception as e:
    # Startup boundary: record the full traceback, then abort the import with
    # a RuntimeError that keeps the original exception as its cause.
    logger.exception("Model load failed: %s", e)
    raise RuntimeError("模型初始化失败") from e
else:
    logger.info("Model loaded successfully")
# === Core detection endpoint ===
@app.post("/detect", description="代码安全检测主接口")
async def detect_vulnerability(payload: dict = Body(...)):
    """Classify a code snippet with the loaded sequence-classification model.

    The user-facing description published in the OpenAPI schema is supplied
    through the decorator; this docstring is for maintainers only.

    Args:
        payload: JSON object from the request body. Its ``"code"`` entry must
            be a string holding the source code to analyse.

    Returns:
        On success, ``{"label": int, "confidence": float}`` where ``label`` is
        the index of the highest logit and ``confidence`` is its softmax
        probability rounded to 4 decimals. On invalid input or inference
        failure, ``{"error": str, "tip": str}``.
    """
    raw_code = payload.get("code", "")
    if raw_code is None:
        # An explicit JSON null is treated the same as a missing field.
        raw_code = ""
    if not isinstance(raw_code, str):
        # Reject non-string values up front instead of letting ``.strip()``
        # raise and surface as an unrelated inference error.
        return {"error": "code 字段必须是字符串", "tip": "请提供有效的代码字符串"}

    code = raw_code.strip()
    if not code:
        return {"error": "代码内容为空", "tip": "请提供有效的代码字符串"}

    # Cap the raw input at 2000 characters before tokenization.
    code = code[:2000]

    try:
        inputs = tokenizer(
            code,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
        )
        # NOTE(review): inference runs synchronously inside this coroutine;
        # consider offloading it if request latency becomes a concern.
        with torch.no_grad():
            outputs = model(**inputs)

        logits = outputs.logits
        # One snippet is submitted per request, so only row 0 is populated;
        # index it explicitly rather than taking an argmax over the
        # flattened tensor.
        label_id = logits[0].argmax(dim=-1).item()
        confidence = logits.softmax(dim=-1)[0][label_id].item()
    except Exception as e:
        # Request boundary: keep the traceback in the log and hand the caller
        # the same error shape as before.
        logger.exception("Error during model inference: %s", e)
        return {
            "error": str(e),
            "tip": "请检查输入代码是否包含非ASCII字符或格式错误",
        }

    logger.info(
        "Code analyzed. Logits: %s, Prediction: %s, Confidence: %.4f",
        logits.tolist(),
        label_id,
        confidence,
    )
    return {
        "label": label_id,  # 0: safe, 1: insecure
        "confidence": round(confidence, 4),
    }
|