from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import os
import logging

# === Initial setup ===
app = FastAPI(title="Code Security API")

# Allow cross-origin requests (CORS)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
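
# NOTE: wildcard origins let any website call this API; restrict
# `allow_origins` for anything beyond a public demo.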

# === Force the Hugging Face cache path ===
os.environ["HF_HOME"] = "/app/.cache/huggingface"
cache_path = os.getenv("HF_HOME")
os.makedirs(cache_path, exist_ok=True)
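# (Assumption: this runs in a container where /app is writable; forcing HF_HOME
# keeps downloaded model weights in a known, writable location.)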

# === Logging setup ===
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("CodeBERT-API")

# === Root route (required, used as a health check) ===
@app.get("/")
async def read_root():
    """健康检查端点"""
    return {
        "status": "running",
        "endpoints": {
            "detect": "POST /detect - 代码安全检测",
            "specs": "GET /openapi.json - API文档"
        }
    }

# === Model loading ===
try:
    logger.info("Loading model from: %s", cache_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        "mrm8488/codebert-base-finetuned-detect-insecure-code",
        cache_dir=cache_path
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "mrm8488/codebert-base-finetuned-detect-insecure-code",
        cache_dir=cache_path
    )
    logger.info("Model loaded successfully")
except Exception as e:
    logger.error("Model load failed: %s", str(e))
    raise RuntimeError("模型初始化失败")

# === Core detection endpoint ===
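# NOTE: `code` is declared as a bare `str`, so FastAPI exposes it as a query
# parameter (POST /detect?code=...) rather than as the request body.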
@app.post("/detect")
async def detect_vulnerability(code: str):
    """代码安全检测主接口"""
    try:
        # Input handling
        code = code[:2000]  # truncate overly long input
        
        # Model inference
        inputs = tokenizer(
            code,
            return_tensors="pt",
            truncation=True,
            max_length=512
        )
        with torch.no_grad():
            outputs = model(**inputs)
        
        # Parse the result
        label_id = outputs.logits.argmax().item()
        return {
            "label": label_id,  # 0:安全 1:不安全
            "confidence": outputs.logits.softmax(dim=-1)[0][label_id].item()
        }
        
    except Exception as e:
        return {
            "error": str(e),
            "tip": "请检查输入代码是否包含非ASCII字符"
        }
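
# --- Local run / usage sketch ---
# (Assumptions: uvicorn is installed; the host and port below are illustrative
# and not part of the original service configuration.)
if __name__ == "__main__":
    import uvicorn

    # Serve the API; port 7860 is the convention on Hugging Face Spaces.
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request (the bare `code: str` parameter is read from the query string):
#   curl -X POST "http://localhost:7860/detect?code=eval(user_input)"
# Expected response shape: {"label": 0 or 1, "confidence": <float>}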