File size: 2,995 Bytes
e6e9cd7
981a5ac
e6e9cd7
981a5ac
08465c2
d763957
7de8bc5
08465c2
d763957
7de8bc5
d763957
7de8bc5
d763957
981a5ac
 
d763957
7de8bc5
d763957
 
08465c2
981a5ac
d763957
 
 
 
 
19776d8
d763957
981a5ac
d763957
981a5ac
d763957
 
981a5ac
d763957
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
981a5ac
d763957
 
 
981a5ac
 
 
d763957
 
 
981a5ac
 
d763957
981a5ac
 
 
d763957
 
 
981a5ac
d763957
 
981a5ac
d763957
 
 
 
 
 
981a5ac
d763957
 
981a5ac
d763957
981a5ac
d763957
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import re
from fastapi import FastAPI
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Hugging Face cache location.
# NOTE: huggingface_hub / transformers read HF_HOME when they are first
# imported, and `transformers` is already imported above, so setting the
# variable here does not by itself redirect the download cache. The directory
# is therefore also passed explicitly as `cache_dir` to from_pretrained below.
os.environ["HF_HOME"] = "/app/.cache/huggingface"
# The Hub's default cache layout is "$HF_HOME/hub".
HF_CACHE_DIR = os.path.join(os.environ["HF_HOME"], "hub")
MODEL_NAME = "Salesforce/codet5-small"

app = FastAPI(title="代码安全审计API", version="1.2.0")

# Load tokenizer and model once at import time; abort startup if that fails.
try:
    print("正在加载模型...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=HF_CACHE_DIR)
    model = T5ForConditionalGeneration.from_pretrained(
        MODEL_NAME, cache_dir=HF_CACHE_DIR
    )
    print("✅ 模型加载成功!")
except Exception as e:
    print(f"❌ 模型加载失败: {str(e)}")
    # Chain the original error so the real cause shows up in the traceback.
    raise RuntimeError("模型初始化失败,请检查模型路径或网络连接") from e

# Characters removed from submitted code before it is embedded in the prompt.
_UNSAFE_CHARS = re.compile(r"[<>&\"'\\]")
# Upper bound on the sanitized snippet length, in characters.
_MAX_CODE_CHARS = 2000


def sanitize_code(code: str) -> str:
    """Strip ``< > & " ' \\`` from *code* and truncate the result.

    Args:
        code: Raw code snippet supplied by the client.

    Returns:
        The snippet with those characters deleted, limited to its first
        2000 characters (the limit is applied after the deletion).

    NOTE(review): quotes and angle brackets are exactly what SQL-injection
    and XSS findings hinge on; removing them may hide the vulnerabilities
    this service is meant to report. Confirm this filtering is intended.
    """
    cleaned = _UNSAFE_CHARS.sub("", code)
    return cleaned[:_MAX_CODE_CHARS]

@app.get("/analyze", summary="代码漏洞分析", description="输入代码片段,返回安全分析结果")
async def analyze_get(code: str):
    """Analyze a code snippet for security issues with the seq2seq model.

    Args:
        code: Code snippet from the ``code`` query parameter. It is passed
            through ``sanitize_code`` before being embedded in the prompt.

    Returns:
        ``{"result": <decoded model output>}`` on success, or
        ``{"error": ...}`` if any step raises (the exception is printed).
    """
    # NOTE(review): tokenization and model.generate() below are synchronous,
    # compute-bound calls inside an ``async def``, so they block the event
    # loop (including the "/" health check) for the whole generation.
    # Consider offloading them to a worker thread once concurrent use of the
    # shared tokenizer/model has been confirmed safe.
    try:
        # Preprocess input; the raw value is kept only for logging.
        raw_code = code
        code = sanitize_code(code)
        print(f"\n🔍 原始输入:\n{raw_code}\n🛡️ 清洗后:\n{code}\n")

        # Instruction-style prompt wrapping the sanitized snippet.
        prompt = f"""作为资深安全工程师,请分析以下代码的安全漏洞:

**要求**:
1. 检查以下漏洞类型:SQL注入、XSS、命令注入、路径遍历、敏感信息泄露
2. 用中文按格式返回:
[漏洞类型] [危险等级]:具体描述
[修复建议]:解决方案

**代码片段**:
{code}"""

        print(f"📝 提示词:\n{prompt}\n")

        # Tokenize. Only one sequence is encoded, so no padding is needed:
        # padding="max_length" just pushed up to 512 masked pad tokens through
        # the encoder and made the token count logged below always read 512.
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
        print(f"🔢 输入Token数量: {len(inputs.input_ids[0])}")

        # Deterministic beam search. `temperature` / `top_p` were dropped:
        # they only take effect with do_sample=True, so generate() ignored
        # them (and warned about it).
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=512,
            num_beams=5,
            early_stopping=True,
            repetition_penalty=1.2,
        )

        # Decode the best beam and strip special tokens.
        analysis = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"⚙️ 原始输出:\n{analysis}\n")

        # Remove an echoed prompt heading if the model reproduced it.
        cleaned_analysis = analysis.replace("**代码片段**", "").strip()
        return {"result": cleaned_analysis}

    except Exception as e:
        # Top-level request boundary: report a generic error to the client.
        print(f"❌ 错误追踪: {str(e)}")
        return {"error": "分析失败,请检查输入格式"}

@app.get("/", include_in_schema=False)
async def health_check():
    """Liveness probe: report service status, model id and API version.

    Returns:
        A dict with ``status``, ``model`` and ``version`` keys.
    """
    payload = dict(
        status="running",
        model="Salesforce/codet5-small",
        version="1.2.0",
    )
    return payload