File size: 1,775 Bytes
bac242d
e28e6dd
a367787
bac242d
338753c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81e40a8
e28e6dd
eb9892c
fc986a8
 
e28e6dd
 
 
fc986a8
eb9892c
fc986a8
338753c
 
 
 
 
 
 
 
113ca35
338753c
bac242d
e28e6dd
338753c
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import logging
import os
from pathlib import Path

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# === Logging setup ===
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# === Cache directory permission check ===
def check_permissions(cache_dir=None):
    """Verify that the Hugging Face cache directory exists and is writable.

    The directory is created if missing. A small probe file is then written
    and deleted to prove write access.

    Args:
        cache_dir: Directory to check. Defaults to ``$HF_HOME``. When that is
            unset or empty, falls back to the Hugging Face default location,
            ``$XDG_CACHE_HOME/huggingface`` or ``~/.cache/huggingface``.

    Returns:
        The ``Path`` that was checked.

    Raises:
        RuntimeError: If the directory cannot be created, written to, or
            cleaned up. The underlying ``OSError`` is chained as ``__cause__``.
    """
    if cache_dir is None:
        cache_dir = os.getenv("HF_HOME")
    if not cache_dir:
        # Path("") would resolve to the current working directory, which is
        # not where the model cache lives; use the hub's default instead.
        xdg_cache = os.getenv("XDG_CACHE_HOME") or Path.home() / ".cache"
        cache_dir = Path(xdg_cache) / "huggingface"
    cache_path = Path(cache_dir).expanduser()
    # Per-process probe name so concurrent workers importing this module do
    # not delete each other's probe file mid-check.
    probe = cache_path / f"permission_test_{os.getpid()}.txt"
    try:
        cache_path.mkdir(parents=True, exist_ok=True)
        probe.write_text("test")
        probe.unlink()
    except OSError as e:
        logger.error("❌ 缓存目录权限异常: %s", e)
        raise RuntimeError(f"Directory permission error: {e}") from e
    logger.info("✅ 缓存目录权限正常: %s", cache_path)
    return cache_path

# Fail fast at import time if the cache directory is not writable, before
# the model download below is attempted.
check_permissions()

# === FastAPI setup ===
app = FastAPI()
# Wide-open CORS policy: every origin, HTTP method and request header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# === Model loading ===
# Hub id of the classifier; model and tokenizer must come from the same
# checkpoint, so the id is kept in one place.
MODEL_NAME = "mrm8488/codebert-base-finetuned-detect-insecure-code"

try:
    logger.info("🔄 加载模型中...")
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    logger.info("✅ 模型加载成功")
except Exception as e:
    # Top-level boundary: log with the full traceback, then re-raise so the
    # service refuses to start without a model.
    logger.exception("❌ 模型加载失败: %s", e)
    raise

# === API endpoint ===
@app.post("/detect")
def detect(code: str):
    """Classify a code snippet with the loaded sequence-classification model.

    Declared as a plain ``def`` rather than ``async def``: model inference is
    blocking work, and FastAPI runs sync endpoints in its threadpool instead
    of on the event loop, so one slow request no longer stalls all others.

    Args:
        code: Source text to classify. As a bare ``str`` parameter, FastAPI
            reads it from the ``code`` query parameter. Only the first 2000
            characters are tokenized, and the token sequence is then truncated
            to at most 512 tokens.

    Returns:
        A dict with ``label`` (int): the index of the highest-probability
        class, and ``score`` (float): the softmax probability of that class.
    """
    inputs = tokenizer(code[:2000], return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits
    # A single input was tokenized, so take row 0 of the class probabilities;
    # label and score then come from the same tensor and always agree.
    probs = logits.softmax(dim=-1)[0]
    score, label = probs.max(dim=-1)
    return {
        "label": int(label),
        "score": score.item(),
    }