Forrest99 committed on
Commit
981a5ac
·
verified ·
1 Parent(s): 19776d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -22
app.py CHANGED
@@ -1,35 +1,60 @@
1
  import os
 
2
  from fastapi import FastAPI
3
- from transformers import pipeline
4
 
5
- # 必须在所有导入前设置环境变量
6
  os.environ["HF_HOME"] = "/app/.cache/huggingface"
7
- os.environ["TRANSFORMERS_CACHE"] = "/app/.cache/huggingface"
8
 
9
  app = FastAPI()
10
 
11
- # 模型加载(带错误处理)
12
  try:
13
- analyzer = pipeline(
14
- "text2text-generation",
15
- model="Salesforce/codet5-small",
16
- tokenizer="Salesforce/codet5-small"
17
- )
18
  except Exception as e:
19
  raise RuntimeError(f"模型加载失败: {str(e)}")
20
 
21
- @app.post("/analyze")
22
- async def analyze(code: str):
23
- result = analyzer(
24
- f"Analyze code vulnerabilities:\n{code}",
25
- max_length=512,
26
- num_beams=5
27
- )[0]['generated_text']
28
- return {"result": result}
29
 
30
-
31
- # 添加GET支持,测试模型
32
  @app.get("/analyze")
33
- async def analyze_get(code: str): # 新增GET方法
34
- return await analyze(code)
35
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import re
3
  from fastapi import FastAPI
4
+ from transformers import AutoTokenizer, T5ForConditionalGeneration
5
 
 
6
  os.environ["HF_HOME"] = "/app/.cache/huggingface"
 
7
 
8
  app = FastAPI()
9
 
10
# Load the tokenizer and model once at import time so every request reuses them.
#
# HF_HOME is assigned after `transformers` has already been imported, and
# huggingface_hub resolves its cache location when it is first imported, so the
# environment variable alone does not redirect the cache. Passing `cache_dir`
# explicitly makes the downloads land in the intended directory
# (`<HF_HOME>/hub`, the same layout HF_HOME itself would produce).
_MODEL_NAME = "Salesforce/codet5-small"
_HUB_CACHE_DIR = os.path.join(
    os.environ.get("HF_HOME", "/app/.cache/huggingface"), "hub"
)

try:
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME, cache_dir=_HUB_CACHE_DIR)
    model = T5ForConditionalGeneration.from_pretrained(
        _MODEL_NAME, cache_dir=_HUB_CACHE_DIR
    )
except Exception as e:
    # Fail fast at startup: the service is useless without the model.
    # `from e` keeps the original traceback attached as the explicit cause.
    raise RuntimeError(f"模型加载失败: {e}") from e
16
 
17
def sanitize_code(code: str) -> str:
    """Strip markup-sensitive characters from *code* and cap its length.

    The characters ``<``, ``>``, ``&``, ``"`` and ``'`` are removed first;
    the filtered text is then truncated to at most 1024 characters.
    """
    strip_table = str.maketrans("", "", "<>&\"'")
    filtered = code.translate(strip_table)
    # Truncation happens after filtering, so removed characters do not
    # count against the 1024-character limit.
    return filtered[:1024]
 
 
 
 
21
 
 
 
22
  @app.get("/analyze")
23
async def analyze_get(code: str):
    """Run the CodeT5 model over *code* and return its vulnerability analysis.

    Args:
        code: Source code to analyze, supplied as a query parameter. It is
            passed through ``sanitize_code`` before being placed in the prompt.

    Returns:
        ``{"result": <generated text>}`` on success, or ``{"error": <message>}``
        if any step (sanitizing, tokenizing, generating, decoding) raises.
    """
    try:
        # Clean the input before embedding it in the prompt.
        code = sanitize_code(code)

        # Build the prompt; the Chinese text is part of the model input.
        prompt = (
            "Analyze the following code for security vulnerabilities in Chinese.\n"
            "重点检查SQL注入、XSS、命令注入、路径遍历等问题。\n"
            "按此格式返回:\n[漏洞类型]: [风险描述]\n\n代码:\n"
            f"{code}"
        )

        # Tokenize a single sequence. No padding is requested: padding one
        # input to max_length only adds pad tokens the model must then ignore.
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )

        # Pass the attention mask together with the ids so generate() knows
        # which positions are real tokens. `temperature` is omitted because it
        # has no effect on beam search when sampling is disabled.
        # NOTE(review): generate() is a blocking call inside an async handler;
        # consider offloading it if concurrent requests matter.
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=512,
            num_beams=5,
            early_stopping=True,
        )

        # Decode the best beam back to text.
        analysis = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"result": analysis}

    except Exception as e:
        # Report the failure in the response body instead of raising.
        return {"error": str(e)}
57
+
58
@app.get("/")
async def health_check():
    """Report that the service is running."""
    status_payload = {"status": "active"}
    return status_payload