File size: 2,920 Bytes
d37c72d
e28e6dd
a367787
0a27391
bac242d
338753c
 
0a27391
 
338753c
0a27391
fc986a8
 
e28e6dd
 
 
fc986a8
eb9892c
0a27391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc986a8
338753c
0a27391
 
 
 
 
 
 
 
 
 
338753c
0a27391
 
113ca35
0a27391
bac242d
d37c72d
0a27391
 
d37c72d
 
 
 
 
 
 
0a27391
 
 
 
 
 
 
d37c72d
0a27391
 
d37c72d
0a27391
 
d37c72d
0a27391
d37c72d
 
 
 
 
 
0a27391
 
d37c72d
0a27391
 
 
d37c72d
0a27391
 
d37c72d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from fastapi import FastAPI, Body
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import os
import logging

# === Hugging Face cache location (forced before any model download) ===
os.environ["HF_HOME"] = "/app/.cache/huggingface"
cache_path = os.getenv("HF_HOME")
os.makedirs(cache_path, exist_ok=True)

# === Logging ===
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("CodeBERT-API")

# === FastAPI application ===
app = FastAPI(title="Code Security API")

# Allow browser frontends on any origin to call this API (CORS).
cors_config = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **cors_config)

# === Root route (must be defined) ===
@app.get("/")
async def read_root():
    """Health-check endpoint: reports service status and available routes."""
    endpoints = {
        "detect": "POST /detect - 代码安全检测",
        "specs": "GET /openapi.json - API文档",
    }
    return {"status": "running", "endpoints": endpoints}

# === Model loading ===
# CodeBERT fine-tuned to classify code snippets as secure (0) / insecure (1).
# Single constant so the model and tokenizer cannot drift apart.
MODEL_NAME = "mrm8488/codebert-base-finetuned-detect-insecure-code"

try:
    logger.info("Loading model from: %s", cache_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        cache_dir=cache_path
    )
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        cache_dir=cache_path
    )
    logger.info("Model loaded successfully")
except Exception as e:
    logger.error("Model load failed: %s", str(e))
    # Fail fast at startup; chain the original exception so the traceback
    # shows the real cause (network error, missing cache files, etc.).
    raise RuntimeError("模型初始化失败") from e

# === Core detection endpoint ===
@app.post("/detect")
async def detect_vulnerability(payload: dict = Body(...)):
    """Classify a code snippet as secure or insecure (代码安全检测主接口).

    Expects a JSON body of the form ``{"code": "<source code string>"}``.
    Returns ``{"label": 0|1, "confidence": float}`` on success
    (0 = secure, 1 = insecure), or ``{"error": ..., "tip": ...}`` on
    empty input or inference failure.
    """
    try:
        # Extract the code string from the JSON body.
        code = payload.get("code", "").strip()

        if not code:
            return {"error": "代码内容为空", "tip": "请提供有效的代码字符串"}

        # Hard character cap; the tokenizer truncates again at 512 tokens.
        code = code[:2000]

        # Tokenize for the model.
        inputs = tokenizer(
            code,
            return_tensors="pt",
            truncation=True,
            padding=True,  # automatic padding strategy
            max_length=512
        )

        # Inference only — no gradients needed.
        with torch.no_grad():
            outputs = model(**inputs)

        # argmax over class logits; softmax of the winning class = confidence.
        logits = outputs.logits
        label_id = logits.argmax().item()
        confidence = logits.softmax(dim=-1)[0][label_id].item()

        # Lazy %-style args: formatting is skipped when INFO is disabled.
        logger.info(
            "Code analyzed. Logits: %s, Prediction: %s, Confidence: %.4f",
            logits.tolist(), label_id, confidence,
        )

        return {
            "label": label_id,  # 0: secure, 1: insecure
            "confidence": round(confidence, 4)
        }

    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("Error during model inference")
        # NOTE(review): echoing str(e) back to the client can leak internal
        # details; consider a generic message with the detail logged only.
        return {
            "error": str(e),
            "tip": "请检查输入代码是否包含非ASCII字符或格式错误"
        }