File size: 1,197 Bytes
bac242d
a367787
b6af10b
bac242d
 
b6af10b
 
81e40a8
b6af10b
 
81e40a8
b6af10b
81e40a8
b6af10b
 
81e40a8
b6af10b
bac242d
b6af10b
bac242d
b6af10b
bac242d
b6af10b
 
 
 
 
 
 
 
 
 
 
 
 
bac242d
b6af10b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from fastapi import FastAPI
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import os

# 1. Application setup
app = FastAPI()

# 2. Force the Hugging Face cache path (works around container permission issues).
# NOTE(review): transformers/huggingface_hub resolve the cache location when
# they are *imported*, which happened above — setting HF_HOME here may be too
# late to have any effect. The env var is kept for backward compatibility, but
# cache_dir is also passed explicitly to from_pretrained() below, which always
# works regardless of import order.
os.environ["HF_HOME"] = "/app/.cache/huggingface"

# Single source of truth for the checkpoint ID and cache location, so the
# model and tokenizer can never silently diverge.
_MODEL_ID = "mrm8488/codebert-base-finetuned-detect-insecure-code"
_CACHE_DIR = "/app/.cache/huggingface"

# 3. Load the model and tokenizer (downloaded on first run, then cached).
# Raises RuntimeError at startup so a misconfigured container fails fast
# instead of serving 500s later.
try:
    model = AutoModelForSequenceClassification.from_pretrained(
        _MODEL_ID, cache_dir=_CACHE_DIR
    )
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, cache_dir=_CACHE_DIR)
except Exception as e:
    # Chain the original exception so the root cause stays in the traceback.
    raise RuntimeError(f"模型加载失败: {str(e)}") from e

# 4. Endpoint definition
@app.post("/detect")
def detect(code: str):
    """Classify a code snippet as secure/insecure with the CodeBERT model.

    Declared as a plain ``def`` (not ``async def``) on purpose: torch
    inference is blocking and CPU-bound, and FastAPI runs sync handlers in
    its threadpool, so one request no longer stalls the whole event loop.

    NOTE(review): a bare ``str`` parameter on a POST route is treated by
    FastAPI as a *query* parameter, not the request body — long snippets may
    hit URL length limits. Confirm with callers before switching to Body().

    Args:
        code: Source code to classify (query parameter).

    Returns:
        ``{"label": ..., "score": ...}`` on success, where ``score`` is the
        softmax probability of the predicted label, or ``{"error": ...}``
        with HTTP 200 on failure (best-effort error envelope, kept as-is).
    """
    # Cap raw input length before tokenization; the tokenizer additionally
    # truncates to the model's max sequence length via truncation=True.
    MAX_INPUT_CHARS = 2000
    try:
        if len(code) > MAX_INPUT_CHARS:
            code = code[:MAX_INPUT_CHARS]

        inputs = tokenizer(code, return_tensors="pt", truncation=True)
        # Inference only — disable autograd to avoid building a graph.
        with torch.no_grad():
            outputs = model(**inputs)

        return {
            "label": model.config.id2label[outputs.logits.argmax().item()],
            "score": outputs.logits.softmax(dim=-1).max().item(),
        }

    except Exception as e:
        # Deliberate boundary catch-all: report the failure to the client
        # rather than crashing the worker.
        return {"error": str(e)}