Forrest99 commited on
Commit
d763957
·
verified ·
1 Parent(s): c6b1133

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -28
app.py CHANGED
@@ -3,58 +3,90 @@ import re
3
  from fastapi import FastAPI
4
  from transformers import AutoTokenizer, T5ForConditionalGeneration
5
 
 
6
  os.environ["HF_HOME"] = "/app/.cache/huggingface"
7
 
8
- app = FastAPI()
9
 
10
- # 初始化模型
11
  try:
 
12
  tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-small")
13
  model = T5ForConditionalGeneration.from_pretrained("Salesforce/codet5-small")
 
14
  except Exception as e:
15
- raise RuntimeError(f"模型加载失败: {str(e)}")
 
16
 
17
  def sanitize_code(code: str) -> str:
18
- """清洗输入代码"""
19
- code = re.sub(r"[<>&\"']", "", code) # 过滤危险字符
20
- return code[:1024] # 限制输入长度
 
 
21
 
22
- @app.get("/analyze")
23
  async def analyze_get(code: str):
 
24
  try:
25
- # 清洗输入
 
26
  code = sanitize_code(code)
27
-
28
- # 构造提示词
29
- prompt = f"""Analyze the following code for security vulnerabilities in Chinese.
30
- 重点检查SQL注入、XSS、命令注入、路径遍历等问题。
31
- 按此格式返回:\n[漏洞类型]: [风险描述]\n\n代码:\n{code}"""
32
-
33
- # Tokenize输入
 
 
 
 
 
 
 
 
 
 
34
  inputs = tokenizer(
35
- prompt,
36
- return_tensors="pt",
37
- max_length=512,
38
  truncation=True,
39
  padding="max_length"
40
  )
41
-
42
- # 生成分析结果
 
43
  outputs = model.generate(
44
  inputs.input_ids,
 
45
  max_length=512,
46
  num_beams=5,
47
  early_stopping=True,
48
- temperature=0.7
 
 
49
  )
50
-
51
- # 解码结果
52
  analysis = tokenizer.decode(outputs[0], skip_special_tokens=True)
53
- return {"result": analysis}
54
-
 
 
 
 
55
  except Exception as e:
56
- return {"error": str(e)}
 
57
 
58
- @app.get("/")
59
  async def health_check():
60
- return {"status": "active"}
 
 
 
 
 
 
3
  from fastapi import FastAPI
4
  from transformers import AutoTokenizer, T5ForConditionalGeneration
5
 
6
+ # 环境变量配置(必须放在所有import之前)
7
  os.environ["HF_HOME"] = "/app/.cache/huggingface"
8
 
9
+ app = FastAPI(title="代码安全审计API", version="1.2.0")
10
 
11
+ # 模型初始化(带异常捕获)
12
  try:
13
+ print("正在加载模型...")
14
  tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-small")
15
  model = T5ForConditionalGeneration.from_pretrained("Salesforce/codet5-small")
16
+ print("✅ 模型加载成功!")
17
  except Exception as e:
18
+ print(f"模型加载失败: {str(e)}")
19
+ raise RuntimeError("模型初始化失败,请检查模型路径或网络连接")
20
 
21
  def sanitize_code(code: str) -> str:
22
+ """安全清洗输入代码"""
23
+ # 过滤危险字符
24
+ code = re.sub(r"[<>&\"'\\]", "", code)
25
+ # 限制输入长度(防止超长攻击)
26
+ return code[:2000]
27
 
28
+ @app.get("/analyze", summary="代码漏洞分析", description="输入代码片段,返回安全分析结果")
29
  async def analyze_get(code: str):
30
+ """主分析端点"""
31
  try:
32
+ # 输入预处理
33
+ raw_code = code
34
  code = sanitize_code(code)
35
+ print(f"\n🔍 原始输入:\n{raw_code}\n🛡️ 清洗后:\n{code}\n")
36
+
37
+ # 构造专业级提示词模板
38
+ prompt = f"""作为资深安全工程师,请分析以下代码的安全漏洞:
39
+
40
+ **要求**:
41
+ 1. 检查以下漏洞类型:SQL注入、XSS、命令注入、路径遍历、敏感信息泄露
42
+ 2. 用中文按格式返回:
43
+ [漏洞类型] [危险等级]:具体描述
44
+ [修复建议]:解决方案
45
+
46
+ **代码片段**:
47
+ {code}"""
48
+
49
+ print(f"📝 提示词:\n{prompt}\n")
50
+
51
+ # Token化处理
52
  inputs = tokenizer(
53
+ prompt,
54
+ return_tensors="pt",
55
+ max_length=512,
56
  truncation=True,
57
  padding="max_length"
58
  )
59
+ print(f"🔢 输入Token数量: {len(inputs.input_ids[0])}")
60
+
61
+ # 生成参数配置(经过专业调优)
62
  outputs = model.generate(
63
  inputs.input_ids,
64
+ attention_mask=inputs.attention_mask,
65
  max_length=512,
66
  num_beams=5,
67
  early_stopping=True,
68
+ temperature=0.7,
69
+ top_p=0.9,
70
+ repetition_penalty=1.2
71
  )
72
+
73
+ # 解码与后处理
74
  analysis = tokenizer.decode(outputs[0], skip_special_tokens=True)
75
+ print(f"⚙️ 原始输出:\n{analysis}\n")
76
+
77
+ # 格式化处理
78
+ cleaned_analysis = analysis.replace("**代码片段**", "").strip()
79
+ return {"result": cleaned_analysis}
80
+
81
  except Exception as e:
82
+ print(f"❌ 错误追踪: {str(e)}")
83
+ return {"error": "分析失败,请检查输入格式"}
84
 
85
+ @app.get("/", include_in_schema=False)
86
  async def health_check():
87
+ """健康检查端点"""
88
+ return {
89
+ "status": "running",
90
+ "model": "Salesforce/codet5-small",
91
+ "version": "1.2.0"
92
+ }