import os
import re

# Point the Hugging Face cache at a writable location before importing transformers,
# otherwise the setting may not take effect.
os.environ["HF_HOME"] = "/app/.cache/huggingface"

from fastapi import FastAPI
from transformers import AutoTokenizer, T5ForConditionalGeneration

app = FastAPI()

# Load the tokenizer and model at startup
try:
    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-small")
    model = T5ForConditionalGeneration.from_pretrained("Salesforce/codet5-small")
except Exception as e:
    raise RuntimeError(f"Model loading failed: {str(e)}") from e

def sanitize_code(code: str) -> str:
    """Sanitize the input code."""
    code = re.sub(r"[<>&\"']", "", code)  # Strip potentially dangerous characters
    return code[:1024]  # Cap the input length

@app.get("/analyze")
async def analyze_get(code: str):
    try:
        # Sanitize the input
        code = sanitize_code(code)
        
        # Build the prompt
        prompt = f"""Analyze the following code for security vulnerabilities in Chinese.
重点检查SQL注入、XSS、命令注入、路径遍历等问题。
按此格式返回:\n[漏洞类型]: [风险描述]\n\n代码:\n{code}"""
        
        # Tokenize the input
        inputs = tokenizer(
            prompt, 
            return_tensors="pt", 
            max_length=512, 
            truncation=True,
            padding="max_length"
        )
        
        # Generate the analysis with beam search; pass the attention mask explicitly
        # since the input is padded. temperature is omitted because it has no effect
        # without sampling (do_sample=True).
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=512,
            num_beams=5,
            early_stopping=True
        )
        
        # Decode the result
        analysis = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"result": analysis}
    
    except Exception as e:
        return {"error": str(e)}

@app.get("/")
async def health_check():
    return {"status": "active"}