# codebertBase / app.py
# Last change: "Update app.py" by Forrest99 (commit 0a27391, verified)
# File size: 2.39 kB
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import os
import logging
# === Application setup ===
app = FastAPI(title="Code Security API")

# Work around cross-origin restrictions: every origin, HTTP method and
# request header is accepted.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# === Force the Hugging Face cache location ===
# HF_HOME is overwritten unconditionally; `cache_path` is then read back from
# the environment so both always hold the same value.
os.environ.update({"HF_HOME": "/app/.cache/huggingface"})
cache_path = os.environ["HF_HOME"]
# Create the directory up front; an already existing directory is fine.
os.makedirs(cache_path, exist_ok=True)
# === Logging setup ===
# Named logger used by the rest of this module.
logger = logging.getLogger("CodeBERT-API")
# Basic logging configuration at INFO level.
logging.basicConfig(level=logging.INFO)
# === Root route (must be defined) ===
@app.get("/")
async def read_root():
    """Health-check endpoint: report the service status and list its routes."""
    available_endpoints = {
        "detect": "POST /detect - 代码安全检测",
        "specs": "GET /openapi.json - API文档",
    }
    return {"status": "running", "endpoints": available_endpoints}
# === Model loading ===
# Hub identifier shared by the classifier weights and the matching tokenizer.
MODEL_ID = "mrm8488/codebert-base-finetuned-detect-insecure-code"

# The model and tokenizer are loaded once at import time; any failure aborts
# start-up with a RuntimeError.
try:
    logger.info("Loading model from: %s", cache_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_ID,
        cache_dir=cache_path,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        cache_dir=cache_path,
    )
    logger.info("Model loaded successfully")
except Exception as e:
    # logger.exception keeps the traceback in the log, and `from e` keeps the
    # original cause attached to the RuntimeError seen by the caller.
    logger.exception("Model load failed: %s", e)
    raise RuntimeError("模型初始化失败") from e
# === Core detection endpoint ===
@app.post("/detect")
async def detect_vulnerability(code: str):
    """Classify a code snippet with the loaded sequence-classification model.

    Args:
        code: Source code to analyse. Only the first 2000 characters are
            used, and the tokenizer further truncates to 512 tokens.

    Returns:
        On success, ``{"label": int, "confidence": float}`` where ``label``
        is the predicted class id (0: safe, 1: insecure) and ``confidence``
        is the softmax probability of that class.
        On any failure, ``{"error": str, "tip": str}``; the exception is
        logged and not propagated to the caller.
    """
    try:
        # Input handling: cut off overly long input before tokenisation.
        code = code[:2000]

        # Model inference.
        inputs = tokenizer(
            code,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():  # inference only, no gradients needed
            outputs = model(**inputs)

        # Result parsing. One snippet is scored per request, so take row 0
        # explicitly and reduce over the class axis rather than over the
        # flattened logits tensor.
        logits = outputs.logits[0]
        label_id = logits.argmax(dim=-1).item()
        return {
            "label": label_id,  # 0: safe, 1: insecure
            "confidence": logits.softmax(dim=-1)[label_id].item(),
        }
    except Exception as e:
        # The endpoint deliberately answers with an error payload instead of
        # raising, but the failure is logged with its traceback so it is not
        # lost silently.
        logger.exception("Detection failed: %s", e)
        return {
            "error": str(e),
            "tip": "请检查输入代码是否包含非ASCII字符",
        }